summaryrefslogtreecommitdiff
path: root/test/test_weak.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2022-02-02 11:53:12 +0100
committerGitHub <noreply@github.com>2022-02-02 11:53:12 +0100
commita5e0f0d40d5046a6639924347ef97e2ac80ad0c9 (patch)
treedcd35e851ec2cc3f52eedbfa58fb6970664135c9 /test/test_weak.py
parent71a57c68ea9eb2bc948c4dd1cce9928f34bf20e8 (diff)
[MRG] Add weak OT solver (#341)
* add info in release file * update tests * pep8 * add weak OT example * update plot in doc * correction ewample with empirical sinkhorn * better thumbnail * comment from review * update documenation
Diffstat (limited to 'test/test_weak.py')
-rw-r--r--test/test_weak.py54
1 files changed, 54 insertions, 0 deletions
diff --git a/test/test_weak.py b/test/test_weak.py
new file mode 100644
index 0000000..c4c3278
--- /dev/null
+++ b/test/test_weak.py
@@ -0,0 +1,54 @@
+"""Tests for main module ot.weak """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import ot
+import numpy as np
+
+
+def test_weak_ot():
+ # test weak ot solver and identity stationary point
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ G, log = ot.weak_optimal_transport(xs, xt, u, u, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1))
+ np.testing.assert_allclose(u, G.sum(0))
+
+ # chaeck that identity is recovered
+ G = ot.weak_optimal_transport(xs, xs, G0=np.eye(n) / n)
+
+ # check G is identity
+ np.testing.assert_allclose(G, np.eye(n) / n)
+
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1))
+ np.testing.assert_allclose(u, G.sum(0))
+
+
+def test_weak_ot_bakends(nx):
+ # test weak ot solver for different backends
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ G = ot.weak_optimal_transport(xs, xt, u, u)
+
+ xs2 = nx.from_numpy(xs)
+ xt2 = nx.from_numpy(xt)
+ u2 = nx.from_numpy(u)
+
+ G2 = ot.weak_optimal_transport(xs2, xt2, u2, u2)
+
+ np.testing.assert_allclose(nx.to_numpy(G2), G)