diff options
author | Hicham Janati <hicham.janati@inria.fr> | 2019-07-23 21:28:30 +0200 |
---|---|---|
committer | Hicham Janati <hicham.janati@inria.fr> | 2019-07-23 21:28:30 +0200 |
commit | 09f3f640fc46ba4905d5508b704f2e5a90dda295 (patch) | |
tree | c0b1e2c3644209e5a6c49d676c2a9884fb0bd115 /test | |
parent | 92292231a4d9661c399dbfd97b22d6f7f890f698 (diff) |
fix issue 94 + add test
Diffstat (limited to 'test')
-rw-r--r-- | test/test_bregman.py | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py index 7f4972c..83ebba8 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -254,3 +254,28 @@ def test_empirical_sinkhorn_divergence(): emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn np.testing.assert_allclose( emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn + + +def test_stabilized_vs_sinkhorn_multidim(): + # test if stable version matches sinkhorn + # for multidimensional inputs + n = 100 + + # Gaussian distributions + a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std + b1 = ot.datasets.make_1D_gauss(n, m=60, s=8) + b2 = ot.datasets.make_1D_gauss(n, m=30, s=4) + + # creating matrix A containing all distributions + b = np.vstack((b1, b2)).T + + M = ot.utils.dist0(n) + M /= np.median(M) + epsilon = 0.1 + G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon, + method="sinkhorn_stabilized", + log=True) + G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon, + method="sinkhorn", log=True) + + np.testing.assert_allclose(G, G2) |