From 11381a7ecc79ef719ee9107167c3adc22b5a3f59 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 12 Jun 2019 17:06:32 +0200 Subject: integrate comments of jmassich --- test/test_unbalanced.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) (limited to 'test/test_unbalanced.py') diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 863b6f3..e37498f 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -6,15 +6,19 @@ import numpy as np import ot +import pytest -def test_unbalanced(): +@pytest.mark.parametrize("metric", ["sinkhorn"]) +def test_unbalanced_convergence(method): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) x = rng.randn(n, 2) a = ot.utils.unif(n) + + # make dists unbalanced b = ot.utils.unif(n) * 1.5 M = ot.dist(x, x) @@ -23,7 +27,8 @@ def test_unbalanced(): K = np.exp(- M / epsilon) G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha, - stopThr=1e-10, log=True) + stopThr=1e-10, method=method, + log=True) # check fixed point equations fi = alpha / (alpha + epsilon) -- cgit v1.2.3