summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2019-07-25 08:22:16 +0200
committerGitHub <noreply@github.com>2019-07-25 08:22:16 +0200
commitc64906b70d6059fd56836e8f760f4a836867d51b (patch)
tree96f505830c031022580d32d08d1d2081e9e45204
parent0063cb87a10293a24ad1c9483be121745958c24a (diff)
parenta507556b1901e16351c211e69b38d8d74ac2bc3d (diff)
Merge pull request #97 from hichamjanati/fix_mismatch_error_94
[MRG] Fix mismatch error in stabilized sinkhorn
-rw-r--r--ot/bregman.py10
-rw-r--r--test/test_bregman.py25
2 files changed, 32 insertions, 3 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index f39145d..70e4208 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -765,10 +765,14 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
cpt = cpt + 1
- # print('err=',err,' cpt=',cpt)
if log:
- log['logu'] = alpha / reg + np.log(u)
- log['logv'] = beta / reg + np.log(v)
+ if nbb:
+ alpha = alpha[:, None]
+ beta = beta[:, None]
+ logu = alpha / reg + np.log(u)
+ logv = beta / reg + np.log(v)
+ log['logu'] = logu
+ log['logv'] = logv
log['alpha'] = alpha + reg * np.log(u)
log['beta'] = beta + reg * np.log(v)
log['warmstart'] = (log['alpha'], log['beta'])
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 7f4972c..83ebba8 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -254,3 +254,28 @@ def test_empirical_sinkhorn_divergence():
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
np.testing.assert_allclose(
emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn
+
+
+def test_stabilized_vs_sinkhorn_multidim():
+ # test if stable version matches sinkhorn
+ # for multidimensional inputs
+ n = 100
+
+ # Gaussian distributions
+ a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ b1 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+ b2 = ot.datasets.make_1D_gauss(n, m=30, s=4)
+
+ # creating matrix A containing all distributions
+ b = np.vstack((b1, b2)).T
+
+ M = ot.utils.dist0(n)
+ M /= np.median(M)
+ epsilon = 0.1
+ G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon,
+ method="sinkhorn_stabilized",
+ log=True)
+ G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon,
+ method="sinkhorn", log=True)
+
+ np.testing.assert_allclose(G, G2)