From 2a3f2241951ea9cc044b4fba8a382b6ae9630513 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Mon, 19 Apr 2021 14:57:51 +0300 Subject: BUG/DOC FIX - Sinkhorn divergence used the wrong weights, and sinkhorn2 didn't support epsilon_scaling method. (#235) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * FIX: 1. Documentation of loss specific functions 2. Sinkhorn divergence weights handling 3. Sinkhorn2 does not support epsilon scaling, so I removed it (it *should* arguably support it, but this would require a refactoring of the sinkhorn iterates pretty much everywhere, maybe should be done in torch first?) * Had some PEP8 issues Co-authored-by: RĂ©mi Flamary --- test/test_bregman.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'test/test_bregman.py') diff --git a/test/test_bregman.py b/test/test_bregman.py index 6aa4e08..331acd3 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -6,9 +6,10 @@ # License: MIT License import numpy as np -import ot import pytest +import ot + def test_sinkhorn(): # test sinkhorn @@ -257,7 +258,8 @@ def test_empirical_sinkhorn(): def test_empirical_sinkhorn_divergence(): # Test sinkhorn divergence n = 10 - a = ot.unif(n) + a = np.linspace(1, n, n) + a /= a.sum() b = ot.unif(n) X_s = np.reshape(np.arange(n), (n, 1)) X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1)) @@ -265,16 +267,15 @@ def test_empirical_sinkhorn_divergence(): M_s = ot.dist(X_s, X_s) M_t = ot.dist(X_t, X_t) - emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1) + emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b) sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1)) - emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, log=True) + emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b, log=True) sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True) sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True) sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True) sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b) - - # check constratints + # check constraints np.testing.assert_allclose( emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn np.testing.assert_allclose( -- cgit v1.2.3