summaryrefslogtreecommitdiff
path: root/test/test_weak.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_weak.py')
-rw-r--r--test/test_weak.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/test/test_weak.py b/test/test_weak.py
new file mode 100644
index 0000000..945efb1
--- /dev/null
+++ b/test/test_weak.py
@@ -0,0 +1,52 @@
+"""Tests for main module ot.weak """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import ot
+import numpy as np
+
+
+def test_weak_ot():
+ # test weak ot solver and identity stationary point
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ G, log = ot.weak_optimal_transport(xs, xt, u, u, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1))
+ np.testing.assert_allclose(u, G.sum(0))
+
+ # chaeck that identity is recovered
+ G = ot.weak_optimal_transport(xs, xs, G0=np.eye(n) / n)
+
+ # check G is identity
+ np.testing.assert_allclose(G, np.eye(n) / n)
+
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1))
+ np.testing.assert_allclose(u, G.sum(0))
+
+
+def test_weak_ot_bakends(nx):
+ # test weak ot solver for different backends
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ G = ot.weak_optimal_transport(xs, xt, u, u)
+
+ xs2, xt2, u2 = nx.from_numpy(xs, xt, u)
+
+ G2 = ot.weak_optimal_transport(xs2, xt2, u2, u2)
+
+ np.testing.assert_allclose(nx.to_numpy(G2), G)