summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorAdrienCorenflos <adrien.corenflos@gmail.com>2021-04-19 14:57:51 +0300
committerGitHub <noreply@github.com>2021-04-19 13:57:51 +0200
commit2a3f2241951ea9cc044b4fba8a382b6ae9630513 (patch)
treec4a07fda0e2ac6495d673df8aba277588bb47783 /test
parent3a2ec71ae7d11aa650a7d3222357885010a9b2c3 (diff)
BUG/DOC FIX - Sinkhorn divergence used the wrong weights, and sinkhorn2 didn't support epsilon_scaling method. (#235)
* FIX: 1. Documentation of loss specific functions 2. Sinkhorn divergence weights handling 3. Sinkhorn2 does not support epsilon scaling, so I removed it (it *should* arguably support it, but this would require a refactoring of the sinkhorn iterates pretty much everywhere, maybe should be done in torch first?) * Had some PEP8 issues Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test')
-rw-r--r--test/test_bregman.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6aa4e08..331acd3 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -6,9 +6,10 @@
# License: MIT License
import numpy as np
-import ot
import pytest
+import ot
+
def test_sinkhorn():
# test sinkhorn
@@ -257,7 +258,8 @@ def test_empirical_sinkhorn():
def test_empirical_sinkhorn_divergence():
# Test sinkhorn divergence
n = 10
- a = ot.unif(n)
+ a = np.linspace(1, n, n)
+ a /= a.sum()
b = ot.unif(n)
X_s = np.reshape(np.arange(n), (n, 1))
X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
@@ -265,16 +267,15 @@ def test_empirical_sinkhorn_divergence():
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)
- emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
+ emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1))
- emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, log=True)
+ emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b, log=True)
sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True)
sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True)
sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True)
sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b)
-
- # check constratints
+ # check constraints
np.testing.assert_allclose(
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
np.testing.assert_allclose(