diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2019-03-11 10:39:03 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2019-03-11 10:39:03 +0100 |
commit | 42a501c5d839c010bbfa3a4440b43cb4f9775fc7 (patch) | |
tree | f885bb6b6edd9b00e02a35b20d1afbcb749e9923 /test | |
parent | 90d04e0f9a3e70d76c9a42b9bbc9c6f6a168269c (diff) |
add test sinkhorn+log
Diffstat (limited to 'test')
-rw-r--r-- | test/test_bregman.py | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py index 14edaf5..90eaf27 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -81,6 +81,31 @@ def test_sinkhorn_variants(): print(G0, G_green) +def test_sinkhorn_variants_log(): + # test sinkhorn + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True) + Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) + Ges, loges = ot.sinkhorn( + u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True) + Gerr, logerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10, log=True) + G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) + + # check values + np.testing.assert_allclose(G0, Gs, atol=1e-05) + np.testing.assert_allclose(G0, Ges, atol=1e-05) + np.testing.assert_allclose(G0, Gerr) + np.testing.assert_allclose(G0, G_green, atol=1e-5) + print(G0, G_green) + + def test_bary(): n_bins = 100 # nb bins |