diff options
author | Kilian Fatras <kilianfatras@Kilians-MacBook-Air.local> | 2019-04-04 13:58:50 +0200 |
---|---|---|
committer | Kilian Fatras <kilianfatras@Kilians-MacBook-Air.local> | 2019-04-04 13:58:50 +0200 |
commit | 69186a6f4259d32fecac370f59efe16e2e460d04 (patch) | |
tree | 7577b7e5c6faaf3ce8eebed8b56ce8af1e7614b1 /test | |
parent | 780bdfee3c622698dc9b18a02fa06381314aa56d (diff) |
fix test sinkhorn div
Diffstat (limited to 'test')
-rw-r--r-- | test/test_bregman.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py index 68d3595..58700e2 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -241,11 +241,13 @@ def test_empirical_sinkhorn_divergence(): M_t = ot.dist(X_t, X_t) emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1) - sinkhorn_div = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1) - - ot.sinkhorn2(b, b, M_t, 1)) + sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1)) emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 0.1, log=True) - sinkhorn_div_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True) + sink_div_log, log_s = ot.sinkhorn2(a, b, M, 1) + sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1) + sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1) + sink_div_log = sink_div_log - 1 / 2 * (sink_div_log_a + sink_div_log_b) # check constratints np.testing.assert_allclose( |