diff options
author | Hicham Janati <hicham.janati@inria.fr> | 2019-06-12 17:06:32 +0200 |
---|---|---|
committer | Hicham Janati <hicham.janati@inria.fr> | 2019-06-12 17:06:32 +0200 |
commit | 11381a7ecc79ef719ee9107167c3adc22b5a3f59 (patch) | |
tree | fd2b3b7c4ae59bc4050e6b69579f940e2a5a5f18 /test/test_unbalanced.py | |
parent | 28b549ef3ef93c01462cd811d6e55c36ae5a76a2 (diff) |
integrate comments of jmassich
Diffstat (limited to 'test/test_unbalanced.py')
-rw-r--r-- | test/test_unbalanced.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 863b6f3..e37498f 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -6,15 +6,19 @@ import numpy as np import ot +import pytest -def test_unbalanced(): +@pytest.mark.parametrize("metric", ["sinkhorn"]) +def test_unbalanced_convergence(method): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) x = rng.randn(n, 2) a = ot.utils.unif(n) + + # make dists unbalanced b = ot.utils.unif(n) * 1.5 M = ot.dist(x, x) @@ -23,7 +27,8 @@ def test_unbalanced(): K = np.exp(- M / epsilon) G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha, - stopThr=1e-10, log=True) + stopThr=1e-10, method=method, + log=True) # check fixed point equations fi = alpha / (alpha + epsilon) |