summaryrefslogtreecommitdiff
path: root/test/test_solvers.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2022-12-15 09:28:01 +0100
committerGitHub <noreply@github.com>2022-12-15 09:28:01 +0100
commit0411ea22a96f9c22af30156b45c16ef39ffb520d (patch)
tree7c131ad804d5b16a8c362c2fe296350a770400df /test/test_solvers.py
parent8490196dcc982c492b7565e1ec4de5f75f006acf (diff)
[MRG] New API for OT solver (with pre-computed ground cost matrix) (#388)
* new API for OT solver * use itertools for product of parameters * add tests for result class * add tests for result class * add tests for result class last time? * add sinkhorn * make partial OT bckend compatible * add TV as unbalanced flavor * better tests * make smoth backend compatible and add l2 tregularizatio to solve * add reularizedd unbalanced * add test for more complex attibutes * add test for more complex attibutes * add generic unbalaned solver and implement it for ot.solve * add entropy to possible regularization * star of documentation for ot.solv * weird new pep8 * documenttaion for function ot.solve done * pep8 * Update ot/solvers.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * update release file * Apply suggestions from code review Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * add test NotImplemented * pep8 * pep8gcmp pep8! * compute kl in backend * debug tensorflow kl backend Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'test/test_solvers.py')
-rw-r--r--test/test_solvers.py133
1 files changed, 133 insertions, 0 deletions
diff --git a/test/test_solvers.py b/test/test_solvers.py
new file mode 100644
index 0000000..b792aca
--- /dev/null
+++ b/test/test_solvers.py
@@ -0,0 +1,133 @@
+"""Tests for ot solvers"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+
+import itertools
+import numpy as np
+import pytest
+
+import ot
+
+
+lst_reg = [None, 1.0]
+lst_reg_type = ['KL', 'entropy', 'L2']
+lst_unbalanced = [None, 0.9]
+lst_unbalanced_type = ['KL', 'L2', 'TV']
+
+
+def assert_allclose_sol(sol1, sol2):
+
+ lst_attr = ['value', 'value_linear', 'plan',
+ 'potential_a', 'potential_b', 'marginal_a', 'marginal_b']
+
+ nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend()
+ nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend()
+
+ for attr in lst_attr:
+ try:
+ np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr)))
+ except NotImplementedError:
+ pass
+
+
+def test_solve(nx):
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+ a = ot.utils.unif(n_samples_s)
+ b = ot.utils.unif(n_samples_t)
+
+ M = ot.dist(x, y)
+
+ # solve unif weights
+ sol0 = ot.solve(M)
+
+ print(sol0)
+
+ # solve signe weights
+ sol = ot.solve(M, a, b)
+
+ # check some attributes
+ sol.potentials
+ sol.sparse_plan
+ sol.marginals
+ sol.status
+
+ assert_allclose_sol(sol0, sol)
+
+ # solve in backend
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+ solb = ot.solve(M, a, b)
+
+ assert_allclose_sol(sol, solb)
+
+ # test not implemented unbalanced and check raise
+ with pytest.raises(NotImplementedError):
+ sol0 = ot.solve(M, unbalanced=1, unbalanced_type='cryptic divergence')
+
+ # test not implemented reg_type and check raise
+ with pytest.raises(NotImplementedError):
+ sol0 = ot.solve(M, reg=1, reg_type='cryptic divergence')
+
+
+@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type))
+def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type):
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+ a = ot.utils.unif(n_samples_s)
+ b = ot.utils.unif(n_samples_t)
+
+ M = ot.dist(x, y)
+
+ try:
+
+ # solve unif weights
+ sol0 = ot.solve(M, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ # solve signe weights
+ sol = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ assert_allclose_sol(sol0, sol)
+
+ # solve in backend
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+ solb = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ assert_allclose_sol(sol, solb)
+ except NotImplementedError:
+ pass
+
+
+def test_solve_not_implemented(nx):
+
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+
+ M = ot.dist(x, y)
+
+ # test not implemented and check raise
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, reg=1.0, reg_type='cryptic divergence')
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, unbalanced=1.0, unbalanced_type='cryptic divergence')
+
+ # pairs of incompatible divergences
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, reg=1.0, reg_type='kl', unbalanced=1.0, unbalanced_type='tv')