summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
authorHuy Tran <huytran82125@gmail.com>2021-06-14 13:06:40 +0200
committerGitHub <noreply@github.com>2021-06-14 13:06:40 +0200
commit2dbeeda9308029a8e8db56bed07d48f4d5718efb (patch)
tree60fab2738fa0ffc6fc3c2762171d6d8d133e28c0 /test/test_bregman.py
parent982510eb5085a0edd7a00fb96a308854957d32bf (diff)
[MRG] Batch/Lazy Log Sinkhorn Knopp on samples (#259)
* Add batch implementation of Sinkhorn * Reformat to pep8 and modify parameter * Fix error in batch size * Code review and add test * Fix accidental typo in test_empirical_sinkhorn * Remove whitespace * Edit config.yml
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py44
1 files changed, 44 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 7c5162a..9665229 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -2,6 +2,7 @@
# Author: Remi Flamary <remi.flamary@unice.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License
@@ -329,6 +330,49 @@ def test_empirical_sinkhorn():
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+def test_lazy_empirical_sinkhorn():
+ # test sinkhorn
+ n = 100
+ a = ot.unif(n)
+ b = ot.unif(n)
+ numIterMax = 1000
+
+ X_s = np.reshape(np.arange(n), (n, 1))
+ X_t = np.reshape(np.arange(0, n), (n, 1))
+ M = ot.dist(X_s, X_t)
+ M_m = ot.dist(X_s, X_t, metric='minkowski')
+
+ f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 1), verbose=True)
+ G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
+ sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+
+ f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
+ sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+
+ f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
+ sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+
+ loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+
+ # check constratints
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian
+ np.testing.assert_allclose(
+ sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian
+ np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+
+
def test_empirical_sinkhorn_divergence():
# Test sinkhorn divergence
n = 10