summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorKilian Fatras <kilianfatras@Kilians-MacBook-Air.local>2019-03-29 12:41:43 +0100
committerKilian Fatras <kilianfatras@Kilians-MacBook-Air.local>2019-03-29 12:41:43 +0100
commita2545b5a503c95c9bf07948929b77e9c3f4f28d3 (patch)
tree84bc0c169c1121bdff56e77c2c6cc88a68efba67 /test
parent2384380536e3cc405e4db9f4b31cb48d309f257c (diff)
add empirical sinkhorn and sikhorn divergence functions
Diffstat (limited to 'test')
-rw-r--r--test/test_bregman.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 90eaf27..b890df1 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -1,6 +1,7 @@
"""Tests for module bregman on OT with bregman projections """
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Kilian Fatras <kilian.fatras@irisa.fr>
#
# License: MIT License
@@ -187,3 +188,59 @@ def test_unmix():
ot.bregman.unmix(a, D, M, M0, h0, reg,
1, alpha=0.01, log=True, verbose=True)
+
+
+def test_empirical_sinkhorn():
+ # test sinkhorn
+ n = 100
+ a = ot.unif(n)
+ b = ot.unif(n)
+ M = ot.dist(X_s, X_t)
+ M_e = ot.dist(X_s, X_t, metric='euclidean')
+
+ rng = np.random.RandomState(0)
+
+ X_s = np.reshape(np.arange(n), (n, 1))
+ X_t = np.reshape(np.arange(0, n), (n, 1))
+
+ G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
+ sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+
+ G_e = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
+ sinkhorn_e = ot.sinkhorn(a, b, M_e, 1)
+
+ loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1)
+ loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+
+ # check constratints
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian
+ np.testing.assert_allclose(
+ sinkhorn_e.sum(1), G_e.sum(1), atol=1e-05) # metric euclidian
+ np.testing.assert_allclose(
+ sinkhorn_e.sum(0), G_e.sum(0), atol=1e-05) # metric euclidian
+ np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+
+
+def test_empirical_sinkhorn_divergence():
+ #Test sinkhorn divergence
+ n = 10
+ a = ot.unif(n)
+ b = ot.unif(n)
+ X_s = np.reshape(np.arange(n), (n, 1))
+ X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
+ M = ot.dist(X_s, X_t)
+ M_s = ot.dist(X_s, X_s)
+ M_t = ot.dist(X_t, X_t)
+
+ emp_sinkhorn_div = empirical_sinkhorn_divergence(X_s, X_t, 1)
+ sinkhorn_div = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1) -
+ ot.sinkhorn2(b, b, M_t, 1))
+
+ # check constratints
+ np.testing.assert_allclose(
+ emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
+ np.testing.assert_allclose(
+ emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn