summaryrefslogtreecommitdiff
path: root/test/test_unbalanced.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_unbalanced.py')
-rw-r--r--test/test_unbalanced.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index 863b6f3..e37498f 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -6,15 +6,19 @@
import numpy as np
import ot
+import pytest
-def test_unbalanced():
+@pytest.mark.parametrize("metric", ["sinkhorn"])
+def test_unbalanced_convergence(method):
# test generalized sinkhorn for unbalanced OT
n = 100
rng = np.random.RandomState(42)
x = rng.randn(n, 2)
a = ot.utils.unif(n)
+
+ # make dists unbalanced
b = ot.utils.unif(n) * 1.5
M = ot.dist(x, x)
@@ -23,7 +27,8 @@ def test_unbalanced():
K = np.exp(- M / epsilon)
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha,
- stopThr=1e-10, log=True)
+ stopThr=1e-10, method=method,
+ log=True)
# check fixed point equations
fi = alpha / (alpha + epsilon)