diff options
author | Huy Tran <huytran82125@gmail.com> | 2021-06-14 13:06:40 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-06-14 13:06:40 +0200 |
commit | 2dbeeda9308029a8e8db56bed07d48f4d5718efb (patch) | |
tree | 60fab2738fa0ffc6fc3c2762171d6d8d133e28c0 /test/test_bregman.py | |
parent | 982510eb5085a0edd7a00fb96a308854957d32bf (diff) |
[MRG] Batch/Lazy Log Sinkhorn Knopp on samples (#259)
* Add batch implementation of Sinkhorn
* Reformat to pep8 and modify parameter
* Fix error in batch size
* Code review and add test
* Fix accidental typo in test_empirical_sinkhorn
* Remove whitespace
* Edit config.yml
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r-- | test/test_bregman.py | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py index 7c5162a..9665229 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -2,6 +2,7 @@ # Author: Remi Flamary <remi.flamary@unice.fr> # Kilian Fatras <kilian.fatras@irisa.fr> +# Quang Huy Tran <quang-huy.tran@univ-ubs.fr> # # License: MIT License @@ -329,6 +330,49 @@ def test_empirical_sinkhorn(): np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) +def test_lazy_empirical_sinkhorn(): + # test sinkhorn + n = 100 + a = ot.unif(n) + b = ot.unif(n) + numIterMax = 1000 + + X_s = np.reshape(np.arange(n), (n, 1)) + X_t = np.reshape(np.arange(0, n), (n, 1)) + M = ot.dist(X_s, X_t) + M_m = ot.dist(X_s, X_t, metric='minkowski') + + f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 1), verbose=True) + G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) + sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) + + f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) + sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True) + + f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1) + G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) + sinkhorn_m = ot.sinkhorn(a, b, M_m, 1) + + loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + loss_sinkhorn = ot.sinkhorn2(a, b, M, 1) + + # check constratints + np.testing.assert_allclose( + sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log + np.testing.assert_allclose( + sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log + np.testing.assert_allclose( + sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian + np.testing.assert_allclose( + sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian + np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) + + def test_empirical_sinkhorn_divergence(): # Test sinkhorn divergence n = 10 |