summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorKilian Fatras <kilianfatras@Kilians-MacBook-Air.local>2019-04-04 13:45:33 +0200
committerKilian Fatras <kilianfatras@Kilians-MacBook-Air.local>2019-04-04 13:45:33 +0200
commit780bdfee3c622698dc9b18a02fa06381314aa56d (patch)
treef004296f943e6cce23e26c2e67b27702e0867027 /test
parent7c02007919596dedf9d4555737900e717c3d31a8 (diff)
fix log in sinkhorn div and add log tests
Diffstat (limited to 'test')
-rw-r--r--test/test_bregman.py12
1 files changed, 11 insertions, 1 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 0ebd546..68d3595 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -204,6 +204,9 @@ def test_empirical_sinkhorn():
G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+ G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True)
+ sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+
G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski')
sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
@@ -216,6 +219,10 @@ def test_empirical_sinkhorn():
np.testing.assert_allclose(
sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian
np.testing.assert_allclose(
+ sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log
+ np.testing.assert_allclose(
sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian
np.testing.assert_allclose(
sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian
@@ -237,8 +244,11 @@ def test_empirical_sinkhorn_divergence():
sinkhorn_div = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1) -
ot.sinkhorn2(b, b, M_t, 1))
+ emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 0.1, log=True)
+ sinkhorn_div_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+
# check constratints
np.testing.assert_allclose(
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
np.testing.assert_allclose(
- emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
+ emp_sinkhorn_div_log, sinkhorn_div_log, atol=1e-05) # cf conv emp sinkhorn