diff options
Diffstat (limited to 'test/test_unbalanced.py')
-rw-r--r-- | test/test_unbalanced.py | 33 |
1 files changed, 31 insertions, 2 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index dfeaad9..e8349d1 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -115,7 +115,8 @@ def test_stabilized_vs_sinkhorn(): G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon, method="sinkhorn_stabilized", reg_m=reg_m, - log=True) + log=True, + verbose=True) G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, method="sinkhorn", log=True) @@ -138,7 +139,7 @@ def test_unbalanced_barycenter(method): reg_m = 1. q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, - method=method, log=True) + method=method, log=True, verbose=True) # check fixed point equations fi = reg_m / (reg_m + epsilon) logA = np.log(A + 1e-16) @@ -173,6 +174,7 @@ def test_barycenter_stabilized_vs_sinkhorn(): reg_m=reg_m, log=True, tau=100, method="sinkhorn_stabilized", + verbose=True ) q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, method="sinkhorn", @@ -182,6 +184,33 @@ def test_barycenter_stabilized_vs_sinkhorn(): q, qstable, atol=1e-05) +def test_wrong_method(): + + n = 10 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = ot.utils.unif(n) * 1.5 + + M = ot.dist(x, x) + epsilon = 1. + reg_m = 1. + + with pytest.raises(ValueError): + ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, + reg_m=reg_m, + method='badmethod', + log=True, + verbose=True) + with pytest.raises(ValueError): + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, + method='badmethod', + verbose=True) + + def test_implemented_methods(): IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling'] |