diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2019-07-25 08:22:16 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-07-25 08:22:16 +0200 |
commit | c64906b70d6059fd56836e8f760f4a836867d51b (patch) | |
tree | 96f505830c031022580d32d08d1d2081e9e45204 | |
parent | 0063cb87a10293a24ad1c9483be121745958c24a (diff) | |
parent | a507556b1901e16351c211e69b38d8d74ac2bc3d (diff) |
Merge pull request #97 from hichamjanati/fix_mismatch_error_94
[MRG] Fix mismatch error in stabilized sinkhorn
-rw-r--r-- | ot/bregman.py | 10 | ||||
-rw-r--r-- | test/test_bregman.py | 25 |
2 files changed, 32 insertions, 3 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index f39145d..70e4208 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -765,10 +765,14 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, cpt = cpt + 1 - # print('err=',err,' cpt=',cpt) if log: - log['logu'] = alpha / reg + np.log(u) - log['logv'] = beta / reg + np.log(v) + if nbb: + alpha = alpha[:, None] + beta = beta[:, None] + logu = alpha / reg + np.log(u) + logv = beta / reg + np.log(v) + log['logu'] = logu + log['logv'] = logv log['alpha'] = alpha + reg * np.log(u) log['beta'] = beta + reg * np.log(v) log['warmstart'] = (log['alpha'], log['beta']) diff --git a/test/test_bregman.py b/test/test_bregman.py index 7f4972c..83ebba8 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -254,3 +254,28 @@ def test_empirical_sinkhorn_divergence(): emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn np.testing.assert_allclose( emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn + + +def test_stabilized_vs_sinkhorn_multidim(): + # test if stable version matches sinkhorn + # for multidimensional inputs + n = 100 + + # Gaussian distributions + a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std + b1 = ot.datasets.make_1D_gauss(n, m=60, s=8) + b2 = ot.datasets.make_1D_gauss(n, m=30, s=4) + + # creating matrix A containing all distributions + b = np.vstack((b1, b2)).T + + M = ot.utils.dist0(n) + M /= np.median(M) + epsilon = 0.1 + G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon, + method="sinkhorn_stabilized", + log=True) + G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon, + method="sinkhorn", log=True) + + np.testing.assert_allclose(G, G2) |