summaryrefslogtreecommitdiff
path: root/test/test_factored.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_factored.py')
-rw-r--r--test/test_factored.py56
1 files changed, 56 insertions, 0 deletions
diff --git a/test/test_factored.py b/test/test_factored.py
new file mode 100644
index 0000000..fd2fd01
--- /dev/null
+++ b/test/test_factored.py
@@ -0,0 +1,56 @@
+"""Tests for main module ot.weak """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import ot
+import numpy as np
+
+
+def test_factored_ot():
+ # test weak ot solver and identity stationary point
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, r=10, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, Ga.sum(1))
+ np.testing.assert_allclose(u, Gb.sum(0))
+
+ Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, reg=1, r=10, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, Ga.sum(1))
+ np.testing.assert_allclose(u, Gb.sum(0))
+
+
+def test_factored_ot_backends(nx):
+ # test weak ot solver for different backends
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ xs2 = nx.from_numpy(xs)
+ xt2 = nx.from_numpy(xt)
+ u2 = nx.from_numpy(u)
+
+ Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, u2, u2, r=10)
+
+ # check constraints
+ np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1))
+ np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0))
+
+ Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, reg=1, r=10, X0=X2)
+
+ # check constraints
+ np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1))
+ np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0))