summaryrefslogtreecommitdiff
path: root/test/test_unbalanced.py
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati@inria.fr>2019-06-12 17:06:32 +0200
committerHicham Janati <hicham.janati@inria.fr>2019-06-12 17:06:32 +0200
commit11381a7ecc79ef719ee9107167c3adc22b5a3f59 (patch)
treefd2b3b7c4ae59bc4050e6b69579f940e2a5a5f18 /test/test_unbalanced.py
parent28b549ef3ef93c01462cd811d6e55c36ae5a76a2 (diff)
integrate comments of jmassich
Diffstat (limited to 'test/test_unbalanced.py')
-rw-r--r--test/test_unbalanced.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index 863b6f3..e37498f 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -6,15 +6,19 @@
import numpy as np
import ot
+import pytest
-def test_unbalanced():
+@pytest.mark.parametrize("metric", ["sinkhorn"])
+def test_unbalanced_convergence(method):
# test generalized sinkhorn for unbalanced OT
n = 100
rng = np.random.RandomState(42)
x = rng.randn(n, 2)
a = ot.utils.unif(n)
+
+ # make dists unbalanced
b = ot.utils.unif(n) * 1.5
M = ot.dist(x, x)
@@ -23,7 +27,8 @@ def test_unbalanced():
K = np.exp(- M / epsilon)
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha,
- stopThr=1e-10, log=True)
+ stopThr=1e-10, method=method,
+ log=True)
# check fixed point equations
fi = alpha / (alpha + epsilon)