summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/test_backend.py6
-rwxr-xr-xtest/test_partial.py2
-rw-r--r--test/test_solvers.py133
-rw-r--r--test/test_unbalanced.py23
-rw-r--r--test/test_utils.py29
5 files changed, 193 insertions, 0 deletions
diff --git a/test/test_backend.py b/test/test_backend.py
index 311c075..3628f61 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -275,6 +275,8 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.sqrtm(M)
with pytest.raises(NotImplementedError):
+ nx.kl_div(M, M)
+ with pytest.raises(NotImplementedError):
nx.isfinite(M)
with pytest.raises(NotImplementedError):
nx.array_equal(M, M)
@@ -592,6 +594,10 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append("matrix square root")
+ A = nx.kl_div(nx.abs(Mb), nx.abs(Mb) + 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("Kullback-Leibler divergence")
+
A = nx.concatenate([vb, nx.from_numpy(np.array([np.inf, np.nan]))], axis=0)
A = nx.isfinite(A)
lst_b.append(nx.to_numpy(A))
diff --git a/test/test_partial.py b/test/test_partial.py
index 33fc259..ae4a1ab 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -79,6 +79,8 @@ def test_partial_wasserstein_lagrange():
w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 1, log=True)
+ w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 100, log=True)
+
def test_partial_wasserstein():
diff --git a/test/test_solvers.py b/test/test_solvers.py
new file mode 100644
index 0000000..b792aca
--- /dev/null
+++ b/test/test_solvers.py
@@ -0,0 +1,133 @@
+"""Tests for ot solvers"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+
+import itertools
+import numpy as np
+import pytest
+
+import ot
+
+
+lst_reg = [None, 1.0]
+lst_reg_type = ['KL', 'entropy', 'L2']
+lst_unbalanced = [None, 0.9]
+lst_unbalanced_type = ['KL', 'L2', 'TV']
+
+
+def assert_allclose_sol(sol1, sol2):
+
+ lst_attr = ['value', 'value_linear', 'plan',
+ 'potential_a', 'potential_b', 'marginal_a', 'marginal_b']
+
+ nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend()
+ nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend()
+
+ for attr in lst_attr:
+ try:
+ np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr)))
+ except NotImplementedError:
+ pass
+
+
+def test_solve(nx):
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+ a = ot.utils.unif(n_samples_s)
+ b = ot.utils.unif(n_samples_t)
+
+ M = ot.dist(x, y)
+
+ # solve unif weights
+ sol0 = ot.solve(M)
+
+ print(sol0)
+
+ # solve signe weights
+ sol = ot.solve(M, a, b)
+
+ # check some attributes
+ sol.potentials
+ sol.sparse_plan
+ sol.marginals
+ sol.status
+
+ assert_allclose_sol(sol0, sol)
+
+ # solve in backend
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+ solb = ot.solve(M, a, b)
+
+ assert_allclose_sol(sol, solb)
+
+ # test not implemented unbalanced and check raise
+ with pytest.raises(NotImplementedError):
+ sol0 = ot.solve(M, unbalanced=1, unbalanced_type='cryptic divergence')
+
+ # test not implemented reg_type and check raise
+ with pytest.raises(NotImplementedError):
+ sol0 = ot.solve(M, reg=1, reg_type='cryptic divergence')
+
+
+@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type))
+def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type):
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+ a = ot.utils.unif(n_samples_s)
+ b = ot.utils.unif(n_samples_t)
+
+ M = ot.dist(x, y)
+
+ try:
+
+ # solve unif weights
+ sol0 = ot.solve(M, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ # solve signe weights
+ sol = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ assert_allclose_sol(sol0, sol)
+
+ # solve in backend
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+ solb = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ assert_allclose_sol(sol, solb)
+ except NotImplementedError:
+ pass
+
+
+def test_solve_not_implemented(nx):
+
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+
+ M = ot.dist(x, y)
+
+ # test not implemented and check raise
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, reg=1.0, reg_type='cryptic divergence')
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, unbalanced=1.0, unbalanced_type='cryptic divergence')
+
+ # pairs of incompatible divergences
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, reg=1.0, reg_type='kl', unbalanced=1.0, unbalanced_type='tv')
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index fc40df0..b76d738 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -5,6 +5,7 @@
#
# License: MIT License
+import itertools
import numpy as np
import ot
import pytest
@@ -289,6 +290,28 @@ def test_implemented_methods(nx):
method=method)
+@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2']))
+def test_lbfgsb_unbalanced(nx, reg_div, regm_div):
+
+ np.random.seed(42)
+
+ xs = np.random.randn(5, 2)
+ xt = np.random.randn(6, 2)
+
+ M = ot.dist(xs, xt)
+
+ a = ot.unif(5)
+ b = ot.unif(6)
+
+ G, log = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False)
+
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+
+ Gb, log = ot.unbalanced.lbfgsb_unbalanced(ab, bb, Mb, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False)
+
+ np.testing.assert_allclose(G, nx.to_numpy(Gb))
+
+
def test_mm_convergence(nx):
n = 100
rng = np.random.RandomState(42)
diff --git a/test/test_utils.py b/test/test_utils.py
index 19b6365..666c157 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -301,3 +301,32 @@ def test_BaseEstimator():
cl.set_params(bibi=10)
assert cl.first == 'spam again'
+
+
+def test_OTResult():
+
+ res = ot.utils.OTResult()
+
+ # test print
+ print(res)
+
+ # tets get citation
+ print(res.citation)
+
+ lst_attributes = ['a_to_b',
+ 'b_to_a',
+ 'lazy_plan',
+ 'marginal_a',
+ 'marginal_b',
+ 'marginals',
+ 'plan',
+ 'potential_a',
+ 'potential_b',
+ 'potentials',
+ 'sparse_plan',
+ 'status',
+ 'value',
+ 'value_linear']
+ for at in lst_attributes:
+ with pytest.raises(NotImplementedError):
+ getattr(res, at)