summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2022-04-04 10:23:04 +0200
committerGitHub <noreply@github.com>2022-04-04 10:23:04 +0200
commit0afd84d744a472903d427e3c7ae32e55fdd7b9a7 (patch)
treef9e2ef9b2155fae13591a01bb66c2ea67ce80d18 /test
parent82452e0f5f6dae05c7a1cc384e7a1fb62ae7e0d5 (diff)
[WIP] Add backend dual loss and plan computation for stochastic optimization or regularized OT (#360)
* add losses and plan computations and exmaple for dual oiptimization * pep8 * add nice exmaple * update awesome example stochasti dual * add all tests * pep8 + speedup exmaple * add release info
Diffstat (limited to 'test')
-rw-r--r--test/test_stochastic.py115
1 files changed, 114 insertions, 1 deletions
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 736df32..2b5c0fb 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -8,7 +8,8 @@ for descrete and semicontinous measures from the POT library.
"""
-# Author: Kilian Fatras <kilian.fatras@gmail.com>
+# Authors: Kilian Fatras <kilian.fatras@gmail.com>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
@@ -213,3 +214,115 @@ def test_dual_sgd_sinkhorn():
G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
np.testing.assert_allclose(
G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd
+
+
+def test_loss_dual_entropic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ ot.stochastic.loss_dual_entropic(u, v, xs, xt)
+
+ ot.stochastic.loss_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+
+def test_plan_dual_entropic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ G1 = ot.stochastic.plan_dual_entropic(u, v, xs, xt)
+
+ assert np.all(nx.to_numpy(G1) >= 0)
+ assert G1.shape[0] == 50
+ assert G1.shape[1] == 40
+
+ G2 = ot.stochastic.plan_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+ assert np.all(nx.to_numpy(G2) >= 0)
+ assert G2.shape[0] == 50
+ assert G2.shape[1] == 40
+
+
+def test_loss_dual_quadratic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ ot.stochastic.loss_dual_quadratic(u, v, xs, xt)
+
+ ot.stochastic.loss_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+
+def test_plan_dual_quadratic(nx):
+
+ nx.seed(0)
+
+ xs = nx.randn(50, 2)
+ xt = nx.randn(40, 2) + 2
+
+ ws = nx.rand(50)
+ ws = ws / nx.sum(ws)
+
+ wt = nx.rand(40)
+ wt = wt / nx.sum(wt)
+
+ u = nx.randn(50)
+ v = nx.randn(40)
+
+ def metric(x, y):
+ return -nx.dot(x, y.T)
+
+ G1 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt)
+
+ assert np.all(nx.to_numpy(G1) >= 0)
+ assert G1.shape[0] == 50
+ assert G1.shape[1] == 40
+
+ G2 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+ assert np.all(nx.to_numpy(G2) >= 0)
+ assert G2.shape[0] == 50
+ assert G2.shape[1] == 40