summaryrefslogtreecommitdiff
path: root/test/test_factored.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2022-03-24 14:13:25 +0100
committerGitHub <noreply@github.com>2022-03-24 14:13:25 +0100
commit82452e0f5f6dae05c7a1cc384e7a1fb62ae7e0d5 (patch)
tree051871e3dc63e6bba1d0ecb1df6229796edd33bb /test/test_factored.py
parent767171593f2a98a26b9a39bf110a45085e3b982e (diff)
[MRG] Add factored coupling (#358)
* add gfactored ot * pep8 and add doc * add exmaple for factotred OT * final number of PR * correct test on backends * remove useless loss * better tests
Diffstat (limited to 'test/test_factored.py')
-rw-r--r--test/test_factored.py56
1 files changed, 56 insertions, 0 deletions
diff --git a/test/test_factored.py b/test/test_factored.py
new file mode 100644
index 0000000..fd2fd01
--- /dev/null
+++ b/test/test_factored.py
@@ -0,0 +1,56 @@
+"""Tests for main module ot.weak """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import ot
+import numpy as np
+
+
+def test_factored_ot():
+ # test weak ot solver and identity stationary point
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, r=10, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, Ga.sum(1))
+ np.testing.assert_allclose(u, Gb.sum(0))
+
+ Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, reg=1, r=10, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, Ga.sum(1))
+ np.testing.assert_allclose(u, Gb.sum(0))
+
+
+def test_factored_ot_backends(nx):
+ # test weak ot solver for different backends
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ xs2 = nx.from_numpy(xs)
+ xt2 = nx.from_numpy(xt)
+ u2 = nx.from_numpy(u)
+
+ Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, u2, u2, r=10)
+
+ # check constraints
+ np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1))
+ np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0))
+
+ Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, reg=1, r=10, X0=X2)
+
+ # check constraints
+ np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1))
+ np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0))