author    | Rémi Flamary <remi.flamary@gmail.com> | 2022-04-04 10:23:04 +0200
committer | GitHub <noreply@github.com> | 2022-04-04 10:23:04 +0200
commit    | 0afd84d744a472903d427e3c7ae32e55fdd7b9a7 (patch)
tree      | f9e2ef9b2155fae13591a01bb66c2ea67ce80d18 /test
parent    | 82452e0f5f6dae05c7a1cc384e7a1fb62ae7e0d5 (diff)
[WIP] Add backend dual loss and plan computation for stochastic optimization of regularized OT (#360)
* add losses and plan computations and example for dual optimization (usage sketch below)
* pep8
* add nice example
* update awesome example stochastic dual
* add all tests
* pep8 + speedup example
* add release info
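
Usage sketch (illustration only, not part of the commit): the new entropic helpers called with plain NumPy inputs, mirroring the tests added in the diff below. The dual potentials u and v are random placeholders rather than optimized values; in practice they would come from a stochastic dual solver.

# Minimal sketch of the new API, assuming plain NumPy arrays are accepted
# as in the tests below. u and v are placeholder dual potentials.
import numpy as np
import ot

rng = np.random.RandomState(0)
xs = rng.randn(50, 2)          # source samples
xt = rng.randn(40, 2) + 2      # target samples, shifted by 2

u = rng.randn(50)              # placeholder source dual potential
v = rng.randn(40)              # placeholder target dual potential

# dual loss of entropic-regularized OT for the given potentials and samples
loss = ot.stochastic.loss_dual_entropic(u, v, xs, xt)

# corresponding transport plan between the samples, shape (50, 40)
G = ot.stochastic.plan_dual_entropic(u, v, xs, xt)
print(float(loss), G.shape)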
Diffstat (limited to 'test')
-rw-r--r-- | test/test_stochastic.py | 115 |
1 file changed, 114 insertions, 1 deletion
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 736df32..2b5c0fb 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -8,7 +8,8 @@ for descrete and semicontinous measures from the POT library.
 
 """
 
-# Author: Kilian Fatras <kilian.fatras@gmail.com>
+# Authors: Kilian Fatras <kilian.fatras@gmail.com>
+#          Rémi Flamary <remi.flamary@polytechnique.edu>
 #
 # License: MIT License
 
@@ -213,3 +214,115 @@ def test_dual_sgd_sinkhorn():
         G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
     np.testing.assert_allclose(
         G_sgd, G_sinkhorn, atol=1e-03)  # cf convergence sgd
+
+
+def test_loss_dual_entropic(nx):
+
+    nx.seed(0)
+
+    xs = nx.randn(50, 2)
+    xt = nx.randn(40, 2) + 2
+
+    ws = nx.rand(50)
+    ws = ws / nx.sum(ws)
+
+    wt = nx.rand(40)
+    wt = wt / nx.sum(wt)
+
+    u = nx.randn(50)
+    v = nx.randn(40)
+
+    def metric(x, y):
+        return -nx.dot(x, y.T)
+
+    ot.stochastic.loss_dual_entropic(u, v, xs, xt)
+
+    ot.stochastic.loss_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+
+def test_plan_dual_entropic(nx):
+
+    nx.seed(0)
+
+    xs = nx.randn(50, 2)
+    xt = nx.randn(40, 2) + 2
+
+    ws = nx.rand(50)
+    ws = ws / nx.sum(ws)
+
+    wt = nx.rand(40)
+    wt = wt / nx.sum(wt)
+
+    u = nx.randn(50)
+    v = nx.randn(40)
+
+    def metric(x, y):
+        return -nx.dot(x, y.T)
+
+    G1 = ot.stochastic.plan_dual_entropic(u, v, xs, xt)
+
+    assert np.all(nx.to_numpy(G1) >= 0)
+    assert G1.shape[0] == 50
+    assert G1.shape[1] == 40
+
+    G2 = ot.stochastic.plan_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+    assert np.all(nx.to_numpy(G2) >= 0)
+    assert G2.shape[0] == 50
+    assert G2.shape[1] == 40
+
+
+def test_loss_dual_quadratic(nx):
+
+    nx.seed(0)
+
+    xs = nx.randn(50, 2)
+    xt = nx.randn(40, 2) + 2
+
+    ws = nx.rand(50)
+    ws = ws / nx.sum(ws)
+
+    wt = nx.rand(40)
+    wt = wt / nx.sum(wt)
+
+    u = nx.randn(50)
+    v = nx.randn(40)
+
+    def metric(x, y):
+        return -nx.dot(x, y.T)
+
+    ot.stochastic.loss_dual_quadratic(u, v, xs, xt)
+
+    ot.stochastic.loss_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+
+def test_plan_dual_quadratic(nx):
+
+    nx.seed(0)
+
+    xs = nx.randn(50, 2)
+    xt = nx.randn(40, 2) + 2
+
+    ws = nx.rand(50)
+    ws = ws / nx.sum(ws)
+
+    wt = nx.rand(40)
+    wt = wt / nx.sum(wt)
+
+    u = nx.randn(50)
+    v = nx.randn(40)
+
+    def metric(x, y):
+        return -nx.dot(x, y.T)
+
+    G1 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt)
+
+    assert np.all(nx.to_numpy(G1) >= 0)
+    assert G1.shape[0] == 50
+    assert G1.shape[1] == 40
+
+    G2 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+    assert np.all(nx.to_numpy(G2) >= 0)
+    assert G2.shape[0] == 50
+    assert G2.shape[1] == 40
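
For context, the tests above also exercise a user-provided ground cost via the metric= keyword. A small sketch of what that callable computes, assuming (as in the tests) that it receives the full source and target sample matrices and returns the pairwise cost matrix:

# Sketch of the custom ground cost used in the tests: minus the inner product
# between samples. For xs of shape (50, 2) and xt of shape (40, 2) it yields
# a (50, 40) cost matrix; the callable itself is what is passed as metric=.
import numpy as np

def metric(x, y):
    return -np.dot(x, y.T)

rng = np.random.RandomState(0)
xs = rng.randn(50, 2)
xt = rng.randn(40, 2) + 2

C = metric(xs, xt)
print(C.shape)  # (50, 40)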