From 42a501c5d839c010bbfa3a4440b43cb4f9775fc7 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Mon, 11 Mar 2019 10:39:03 +0100 Subject: add test sinkhorn+log --- test/test_bregman.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) (limited to 'test/test_bregman.py') diff --git a/test/test_bregman.py b/test/test_bregman.py index 14edaf5..90eaf27 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -81,6 +81,31 @@ def test_sinkhorn_variants(): print(G0, G_green) +def test_sinkhorn_variants_log(): + # test sinkhorn + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True) + Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) + Ges, loges = ot.sinkhorn( + u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True) + Gerr, logerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10, log=True) + G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) + + # check values + np.testing.assert_allclose(G0, Gs, atol=1e-05) + np.testing.assert_allclose(G0, Ges, atol=1e-05) + np.testing.assert_allclose(G0, Gerr) + np.testing.assert_allclose(G0, G_green, atol=1e-5) + print(G0, G_green) + + def test_bary(): n_bins = 100 # nb bins -- cgit v1.2.3