diff options
Diffstat (limited to 'test/test_stochastic.py')
-rw-r--r-- | test/test_stochastic.py | 115 |
1 files changed, 114 insertions, 1 deletions
diff --git a/test/test_stochastic.py b/test/test_stochastic.py index 736df32..2b5c0fb 100644 --- a/test/test_stochastic.py +++ b/test/test_stochastic.py @@ -8,7 +8,8 @@ for descrete and semicontinous measures from the POT library. """ -# Author: Kilian Fatras <kilian.fatras@gmail.com> +# Authors: Kilian Fatras <kilian.fatras@gmail.com> +# RĂ©mi Flamary <remi.flamary@polytechnique.edu> # # License: MIT License @@ -213,3 +214,115 @@ def test_dual_sgd_sinkhorn(): G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03) np.testing.assert_allclose( G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd + + +def test_loss_dual_entropic(nx): + + nx.seed(0) + + xs = nx.randn(50, 2) + xt = nx.randn(40, 2) + 2 + + ws = nx.rand(50) + ws = ws / nx.sum(ws) + + wt = nx.rand(40) + wt = wt / nx.sum(wt) + + u = nx.randn(50) + v = nx.randn(40) + + def metric(x, y): + return -nx.dot(x, y.T) + + ot.stochastic.loss_dual_entropic(u, v, xs, xt) + + ot.stochastic.loss_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric) + + +def test_plan_dual_entropic(nx): + + nx.seed(0) + + xs = nx.randn(50, 2) + xt = nx.randn(40, 2) + 2 + + ws = nx.rand(50) + ws = ws / nx.sum(ws) + + wt = nx.rand(40) + wt = wt / nx.sum(wt) + + u = nx.randn(50) + v = nx.randn(40) + + def metric(x, y): + return -nx.dot(x, y.T) + + G1 = ot.stochastic.plan_dual_entropic(u, v, xs, xt) + + assert np.all(nx.to_numpy(G1) >= 0) + assert G1.shape[0] == 50 + assert G1.shape[1] == 40 + + G2 = ot.stochastic.plan_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric) + + assert np.all(nx.to_numpy(G2) >= 0) + assert G2.shape[0] == 50 + assert G2.shape[1] == 40 + + +def test_loss_dual_quadratic(nx): + + nx.seed(0) + + xs = nx.randn(50, 2) + xt = nx.randn(40, 2) + 2 + + ws = nx.rand(50) + ws = ws / nx.sum(ws) + + wt = nx.rand(40) + wt = wt / nx.sum(wt) + + u = nx.randn(50) + v = nx.randn(40) + + def metric(x, y): + return -nx.dot(x, y.T) + + ot.stochastic.loss_dual_quadratic(u, v, xs, xt) + + ot.stochastic.loss_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric) + + +def test_plan_dual_quadratic(nx): + + nx.seed(0) + + xs = nx.randn(50, 2) + xt = nx.randn(40, 2) + 2 + + ws = nx.rand(50) + ws = ws / nx.sum(ws) + + wt = nx.rand(40) + wt = wt / nx.sum(wt) + + u = nx.randn(50) + v = nx.randn(40) + + def metric(x, y): + return -nx.dot(x, y.T) + + G1 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt) + + assert np.all(nx.to_numpy(G1) >= 0) + assert G1.shape[0] == 50 + assert G1.shape[1] == 40 + + G2 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric) + + assert np.all(nx.to_numpy(G2) >= 0) + assert G2.shape[0] == 50 + assert G2.shape[1] == 40 |