summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati@inria.fr>2019-07-23 21:28:30 +0200
committerHicham Janati <hicham.janati@inria.fr>2019-07-23 21:28:30 +0200
commit09f3f640fc46ba4905d5508b704f2e5a90dda295 (patch)
treec0b1e2c3644209e5a6c49d676c2a9884fb0bd115 /test
parent92292231a4d9661c399dbfd97b22d6f7f890f698 (diff)
fix issue 94 + add test
Diffstat (limited to 'test')
-rw-r--r--test/test_bregman.py25
1 files changed, 25 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 7f4972c..83ebba8 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -254,3 +254,28 @@ def test_empirical_sinkhorn_divergence():
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
np.testing.assert_allclose(
emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn
+
+
+def test_stabilized_vs_sinkhorn_multidim():
+ # test if stable version matches sinkhorn
+ # for multidimensional inputs
+ n = 100
+
+ # Gaussian distributions
+ a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ b1 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+ b2 = ot.datasets.make_1D_gauss(n, m=30, s=4)
+
+ # creating matrix A containing all distributions
+ b = np.vstack((b1, b2)).T
+
+ M = ot.utils.dist0(n)
+ M /= np.median(M)
+ epsilon = 0.1
+ G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon,
+ method="sinkhorn_stabilized",
+ log=True)
+ G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon,
+ method="sinkhorn", log=True)
+
+ np.testing.assert_allclose(G, G2)