From 2dbeeda9308029a8e8db56bed07d48f4d5718efb Mon Sep 17 00:00:00 2001 From: Huy Tran Date: Mon, 14 Jun 2021 13:06:40 +0200 Subject: [MRG] Batch/Lazy Log Sinkhorn Knopp on samples (#259) * Add batch implementation of Sinkhorn * Reformat to pep8 and modify parameter * Fix error in batch size * Code review and add test * Fix accidental typo in test_empirical_sinkhorn * Remove whitespace * Edit config.yml --- test/test_bregman.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) (limited to 'test') diff --git a/test/test_bregman.py b/test/test_bregman.py index 7c5162a..9665229 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -2,6 +2,7 @@ # Author: Remi Flamary # Kilian Fatras +# Quang Huy Tran # # License: MIT License @@ -329,6 +330,49 @@ def test_empirical_sinkhorn(): np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) +def test_lazy_empirical_sinkhorn(): + # test sinkhorn + n = 100 + a = ot.unif(n) + b = ot.unif(n) + numIterMax = 1000 + + X_s = np.reshape(np.arange(n), (n, 1)) + X_t = np.reshape(np.arange(0, n), (n, 1)) + M = ot.dist(X_s, X_t) + M_m = ot.dist(X_s, X_t, metric='minkowski') + + f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 1), verbose=True) + G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) + sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) + + f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) + sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True) + + f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1) + G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) + sinkhorn_m = ot.sinkhorn(a, b, M_m, 1) + + loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + loss_sinkhorn = ot.sinkhorn2(a, b, M, 1) + + # check constratints + np.testing.assert_allclose( + sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log + np.testing.assert_allclose( + sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log + np.testing.assert_allclose( + sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian + np.testing.assert_allclose( + sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian + np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) + + def test_empirical_sinkhorn_divergence(): # Test sinkhorn divergence n = 10 -- cgit v1.2.3