summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
authorKilian Fatras <kilianfatras@Kilians-MacBook-Air.local>2019-04-04 14:11:36 +0200
committerKilian Fatras <kilianfatras@Kilians-MacBook-Air.local>2019-04-04 14:11:36 +0200
commit782d9b1ae9d8c0b01e32c2af925ac9b7efa42a70 (patch)
treebf698003156787c40f9d9f320eca7704ebc4aa86 /test/test_bregman.py
parent69186a6f4259d32fecac370f59efe16e2e460d04 (diff)
fix test sinkhorn div
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 58700e2..d5482f7 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -243,11 +243,11 @@ def test_empirical_sinkhorn_divergence():
emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1))
- emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 0.1, log=True)
- sink_div_log, log_s = ot.sinkhorn2(a, b, M, 1)
- sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1)
- sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1)
- sink_div_log = sink_div_log - 1 / 2 * (sink_div_log_a + sink_div_log_b)
+ emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, log=True)
+ sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True)
+ sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True)
+ sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True)
+ sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b)
# check constratints
np.testing.assert_allclose(