summaryrefslogtreecommitdiff
path: root/test/test_stochastic.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_stochastic.py')
-rw-r--r--test/test_stochastic.py115
1 files changed, 114 insertions, 1 deletions
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 736df32..2b5c0fb 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -8,7 +8,8 @@ for descrete and semicontinous measures from the POT library.
"""
-# Author: Kilian Fatras <kilian.fatras@gmail.com>
+# Authors: Kilian Fatras <kilian.fatras@gmail.com>
+# RĂ©mi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
@@ -213,3 +214,115 @@ def test_dual_sgd_sinkhorn():
G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
np.testing.assert_allclose(
G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd
+
+
+def test_loss_dual_entropic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ ot.stochastic.loss_dual_entropic(u, v, xs, xt)
+
+ ot.stochastic.loss_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+
+def test_plan_dual_entropic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ G1 = ot.stochastic.plan_dual_entropic(u, v, xs, xt)
+
+ assert np.all(nx.to_numpy(G1) >= 0)
+ assert G1.shape[0] == 50
+ assert G1.shape[1] == 40
+
+ G2 = ot.stochastic.plan_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+ assert np.all(nx.to_numpy(G2) >= 0)
+ assert G2.shape[0] == 50
+ assert G2.shape[1] == 40
+
+
+def test_loss_dual_quadratic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ ot.stochastic.loss_dual_quadratic(u, v, xs, xt)
+
+ ot.stochastic.loss_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+
+def test_plan_dual_quadratic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ G1 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt)
+
+ assert np.all(nx.to_numpy(G1) >= 0)
+ assert G1.shape[0] == 50
+ assert G1.shape[1] == 40
+
+ G2 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+ assert np.all(nx.to_numpy(G2) >= 0)
+ assert G2.shape[0] == 50
+ assert G2.shape[1] == 40