summaryrefslogtreecommitdiff
path: root/test/test_unbalanced.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_unbalanced.py')
-rw-r--r--test/test_unbalanced.py33
1 files changed, 31 insertions, 2 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index dfeaad9..e8349d1 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -115,7 +115,8 @@ def test_stabilized_vs_sinkhorn():
G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
method="sinkhorn_stabilized",
reg_m=reg_m,
- log=True)
+ log=True,
+ verbose=True)
G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method="sinkhorn", log=True)
@@ -138,7 +139,7 @@ def test_unbalanced_barycenter(method):
reg_m = 1.
q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
- method=method, log=True)
+ method=method, log=True, verbose=True)
# check fixed point equations
fi = reg_m / (reg_m + epsilon)
logA = np.log(A + 1e-16)
@@ -173,6 +174,7 @@ def test_barycenter_stabilized_vs_sinkhorn():
reg_m=reg_m, log=True,
tau=100,
method="sinkhorn_stabilized",
+ verbose=True
)
q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
method="sinkhorn",
@@ -182,6 +184,33 @@ def test_barycenter_stabilized_vs_sinkhorn():
q, qstable, atol=1e-05)
+def test_wrong_method():
+
+ n = 10
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+ # make dists unbalanced
+ b = ot.utils.unif(n) * 1.5
+
+ M = ot.dist(x, x)
+ epsilon = 1.
+ reg_m = 1.
+
+ with pytest.raises(ValueError):
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
+ reg_m=reg_m,
+ method='badmethod',
+ log=True,
+ verbose=True)
+ with pytest.raises(ValueError):
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method='badmethod',
+ verbose=True)
+
+
def test_implemented_methods():
IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']