diff options
author | Kilian Fatras <kilianfatras@Kilians-MacBook-Air.local> | 2019-04-04 13:45:33 +0200 |
---|---|---|
committer | Kilian Fatras <kilianfatras@Kilians-MacBook-Air.local> | 2019-04-04 13:45:33 +0200 |
commit | 780bdfee3c622698dc9b18a02fa06381314aa56d (patch) | |
tree | f004296f943e6cce23e26c2e67b27702e0867027 /test/test_bregman.py | |
parent | 7c02007919596dedf9d4555737900e717c3d31a8 (diff) |
fix log in sinkhorn div and add log tests
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r-- | test/test_bregman.py | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py index 0ebd546..68d3595 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -204,6 +204,9 @@ def test_empirical_sinkhorn(): G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1) sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) + G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True) + sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True) + G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski') sinkhorn_m = ot.sinkhorn(a, b, M_m, 1) @@ -216,6 +219,10 @@ def test_empirical_sinkhorn(): np.testing.assert_allclose( sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian np.testing.assert_allclose( + sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log + np.testing.assert_allclose( + sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log + np.testing.assert_allclose( sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian np.testing.assert_allclose( sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian @@ -237,8 +244,11 @@ def test_empirical_sinkhorn_divergence(): sinkhorn_div = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1) - ot.sinkhorn2(b, b, M_t, 1)) + emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 0.1, log=True) + sinkhorn_div_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True) + # check constratints np.testing.assert_allclose( emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn np.testing.assert_allclose( - emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn + emp_sinkhorn_div_log, sinkhorn_div_log, atol=1e-05) # cf conv emp sinkhorn |