summaryrefslogtreecommitdiff
path: root/test/test_smooth.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-05-31 12:02:13 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-05-31 12:02:13 +0200
commitb188381a8e8190958e4e9aca64b7501cf0321406 (patch)
tree3b43ee2e228e357569d478afaba0d99cca864638 /test/test_smooth.py
parente585d64dd08e5367350e70f23e81f9fd2d676a6b (diff)
add semidual
Diffstat (limited to 'test/test_smooth.py')
-rw-r--r--test/test_smooth.py34
1 files changed, 33 insertions, 1 deletions
diff --git a/test/test_smooth.py b/test/test_smooth.py
index e95b3fe..37cc66e 100644
--- a/test/test_smooth.py
+++ b/test/test_smooth.py
@@ -20,7 +20,7 @@ def test_smooth_ot_dual():
M = ot.dist(x, x)
- Gl2 = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', stopThr=1e-10)
+ Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
# check constratints
np.testing.assert_allclose(
@@ -39,3 +39,35 @@ def test_smooth_ot_dual():
G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
np.testing.assert_allclose(G, G2, atol=1e-05)
+
+
+def test_smooth_ot_semi_dual():
+
+ # get data
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
+
+ # check constratints
+ np.testing.assert_allclose(
+ u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+ # kl regyularisation
+ G = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
+
+ # check constratints
+ np.testing.assert_allclose(
+ u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+ G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
+ np.testing.assert_allclose(G, G2, atol=1e-05)