summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2022-02-02 11:53:12 +0100
committerGitHub <noreply@github.com>2022-02-02 11:53:12 +0100
commita5e0f0d40d5046a6639924347ef97e2ac80ad0c9 (patch)
treedcd35e851ec2cc3f52eedbfa58fb6970664135c9 /test
parent71a57c68ea9eb2bc948c4dd1cce9928f34bf20e8 (diff)
[MRG] Add weak OT solver (#341)
* add info in release file * update tests * pep8 * add weak OT example * update plot in doc * correction ewample with empirical sinkhorn * better thumbnail * comment from review * update documenation
Diffstat (limited to 'test')
-rw-r--r--test/test_bregman.py13
-rw-r--r--test/test_ot.py2
-rw-r--r--test/test_utils.py18
-rw-r--r--test/test_weak.py54
4 files changed, 75 insertions, 12 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6e90aa4..1419f9b 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -60,7 +60,7 @@ def test_convergence_warning(method):
ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
-def test_not_impemented_method():
+def test_not_implemented_method():
# test sinkhorn
w = 10
n = w ** 2
@@ -635,7 +635,7 @@ def test_wasserstein_bary_2d(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
else:
- bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method)
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True)
bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
@@ -667,7 +667,7 @@ def test_wasserstein_bary_2d_debiased(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
else:
- bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method)
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True)
bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
@@ -940,14 +940,11 @@ def test_screenkhorn(nx):
bb = nx.from_numpy(b)
M_nx = nx.from_numpy(M, type_as=ab)
- # np sinkhorn
- G_sink_np = ot.sinkhorn(a, b, M, 1e-03)
# sinkhorn
- G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03))
+ G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1))
# screenkhorn
- G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True))
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True))
# check marginals
- np.testing.assert_allclose(G_sink_np, G_sink)
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
diff --git a/test/test_ot.py b/test/test_ot.py
index e8e2d97..3e2d845 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -232,7 +232,7 @@ def test_emd2_multi():
# Gaussian distributions
a = gauss(n, m=20, s=5) # m= mean, s= std
- ls = np.arange(20, 500, 20)
+ ls = np.arange(20, 500, 100)
nb = len(ls)
b = np.zeros((n, nb))
for i in range(nb):
diff --git a/test/test_utils.py b/test/test_utils.py
index 8b23c22..5ad167b 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -62,12 +62,12 @@ def test_tic_toc():
import time
ot.tic()
- time.sleep(0.5)
+ time.sleep(0.1)
t = ot.toc()
t2 = ot.toq()
# test timing
- np.testing.assert_allclose(0.5, t, rtol=1e-1, atol=1e-1)
+ np.testing.assert_allclose(0.1, t, rtol=1e-1, atol=1e-1)
# test toc vs toq
np.testing.assert_allclose(t, t2, rtol=1e-1, atol=1e-1)
@@ -94,10 +94,22 @@ def test_unif():
np.testing.assert_allclose(1, np.sum(u))
-def test_dist():
+def test_unif_backend(nx):
n = 100
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ u = ot.unif(n, type_as=tp)
+
+ np.testing.assert_allclose(1, np.sum(nx.to_numpy(u)), atol=1e-6)
+
+
+def test_dist():
+
+ n = 10
+
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
diff --git a/test/test_weak.py b/test/test_weak.py
new file mode 100644
index 0000000..c4c3278
--- /dev/null
+++ b/test/test_weak.py
@@ -0,0 +1,54 @@
+"""Tests for main module ot.weak """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import ot
+import numpy as np
+
+
+def test_weak_ot():
+ # test weak ot solver and identity stationary point
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ G, log = ot.weak_optimal_transport(xs, xt, u, u, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1))
+ np.testing.assert_allclose(u, G.sum(0))
+
+ # chaeck that identity is recovered
+ G = ot.weak_optimal_transport(xs, xs, G0=np.eye(n) / n)
+
+ # check G is identity
+ np.testing.assert_allclose(G, np.eye(n) / n)
+
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1))
+ np.testing.assert_allclose(u, G.sum(0))
+
+
+def test_weak_ot_bakends(nx):
+ # test weak ot solver for different backends
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ G = ot.weak_optimal_transport(xs, xt, u, u)
+
+ xs2 = nx.from_numpy(xs)
+ xt2 = nx.from_numpy(xt)
+ u2 = nx.from_numpy(u)
+
+ G2 = ot.weak_optimal_transport(xs2, xt2, u2, u2)
+
+ np.testing.assert_allclose(nx.to_numpy(G2), G)