From a2545b5a503c95c9bf07948929b77e9c3f4f28d3 Mon Sep 17 00:00:00 2001 From: Kilian Fatras Date: Fri, 29 Mar 2019 12:41:43 +0100 Subject: add empirical sinkhorn and sikhorn divergence functions --- test/test_bregman.py | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) (limited to 'test/test_bregman.py') diff --git a/test/test_bregman.py b/test/test_bregman.py index 90eaf27..b890df1 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -1,6 +1,7 @@ """Tests for module bregman on OT with bregman projections """ # Author: Remi Flamary +# Kilian Fatras # # License: MIT License @@ -187,3 +188,59 @@ def test_unmix(): ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, log=True, verbose=True) + + +def test_empirical_sinkhorn(): + # test sinkhorn + n = 100 + a = ot.unif(n) + b = ot.unif(n) + M = ot.dist(X_s, X_t) + M_e = ot.dist(X_s, X_t, metric='euclidean') + + rng = np.random.RandomState(0) + + X_s = np.reshape(np.arange(n), (n, 1)) + X_t = np.reshape(np.arange(0, n), (n, 1)) + + G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1) + sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) + + G_e = ot.bregman.empirical_sinkhorn(X_s, X_t, 1) + sinkhorn_e = ot.sinkhorn(a, b, M_e, 1) + + loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1) + loss_sinkhorn = ot.sinkhorn2(a, b, M, 1) + + # check constratints + np.testing.assert_allclose( + sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_e.sum(1), G_e.sum(1), atol=1e-05) # metric euclidian + np.testing.assert_allclose( + sinkhorn_e.sum(0), G_e.sum(0), atol=1e-05) # metric euclidian + np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) + + +def test_empirical_sinkhorn_divergence(): + #Test sinkhorn divergence + n = 10 + a = ot.unif(n) + b = ot.unif(n) + X_s = np.reshape(np.arange(n), (n, 1)) + X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1)) + M = ot.dist(X_s, X_t) + M_s = ot.dist(X_s, X_s) + M_t = ot.dist(X_t, X_t) + + emp_sinkhorn_div = empirical_sinkhorn_divergence(X_s, X_t, 1) + sinkhorn_div = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1) - + ot.sinkhorn2(b, b, M_t, 1)) + + # check constratints + np.testing.assert_allclose( + emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn + np.testing.assert_allclose( + emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn -- cgit v1.2.3