summaryrefslogtreecommitdiff
path: root/test/test_weak.py
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2022-04-27 11:49:23 +0200
committerGard Spreemann <gspr@nonempty.org>2022-04-27 11:49:23 +0200
commit35bd2c98b642df78638d7d733bc1a89d873db1de (patch)
tree6bc637624004713808d3097b95acdccbb9608e52 /test/test_weak.py
parentc4753bd3f74139af8380127b66b484bc09b50661 (diff)
parenteccb1386eea52b94b82456d126bd20cbe3198e05 (diff)
Merge tag '0.8.2' into dfsg/latest
Diffstat (limited to 'test/test_weak.py')
-rw-r--r--test/test_weak.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/test/test_weak.py b/test/test_weak.py
new file mode 100644
index 0000000..945efb1
--- /dev/null
+++ b/test/test_weak.py
@@ -0,0 +1,52 @@
+"""Tests for main module ot.weak """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import ot
+import numpy as np
+
+
+def test_weak_ot():
+ # test weak ot solver and identity stationary point
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ G, log = ot.weak_optimal_transport(xs, xt, u, u, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1))
+ np.testing.assert_allclose(u, G.sum(0))
+
+ # chaeck that identity is recovered
+ G = ot.weak_optimal_transport(xs, xs, G0=np.eye(n) / n)
+
+ # check G is identity
+ np.testing.assert_allclose(G, np.eye(n) / n)
+
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1))
+ np.testing.assert_allclose(u, G.sum(0))
+
+
+def test_weak_ot_bakends(nx):
+ # test weak ot solver for different backends
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ G = ot.weak_optimal_transport(xs, xt, u, u)
+
+ xs2, xt2, u2 = nx.from_numpy(xs, xt, u)
+
+ G2 = ot.weak_optimal_transport(xs2, xt2, u2, u2)
+
+ np.testing.assert_allclose(nx.to_numpy(G2), G)