summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2019-03-11 10:39:03 +0100
committerRémi Flamary <remi.flamary@gmail.com>2019-03-11 10:39:03 +0100
commit42a501c5d839c010bbfa3a4440b43cb4f9775fc7 (patch)
treef885bb6b6edd9b00e02a35b20d1afbcb749e9923 /test
parent90d04e0f9a3e70d76c9a42b9bbc9c6f6a168269c (diff)
add test sinkhorn+log
Diffstat (limited to 'test')
-rw-r--r--test/test_bregman.py25
1 files changed, 25 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 14edaf5..90eaf27 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -81,6 +81,31 @@ def test_sinkhorn_variants():
print(G0, G_green)
+def test_sinkhorn_variants_log():
+ # test sinkhorn
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
+ Ges, loges = ot.sinkhorn(
+ u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
+ Gerr, logerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10, log=True)
+ G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+
+ # check values
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Ges, atol=1e-05)
+ np.testing.assert_allclose(G0, Gerr)
+ np.testing.assert_allclose(G0, G_green, atol=1e-5)
+ print(G0, G_green)
+
+
def test_bary():
n_bins = 100 # nb bins