diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test_backend.py | 6 | ||||
-rwxr-xr-x | test/test_partial.py | 2 | ||||
-rw-r--r-- | test/test_solvers.py | 133 | ||||
-rw-r--r-- | test/test_unbalanced.py | 23 | ||||
-rw-r--r-- | test/test_utils.py | 29 |
5 files changed, 193 insertions, 0 deletions
diff --git a/test/test_backend.py b/test/test_backend.py index 311c075..3628f61 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -275,6 +275,8 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.sqrtm(M) with pytest.raises(NotImplementedError): + nx.kl_div(M, M) + with pytest.raises(NotImplementedError): nx.isfinite(M) with pytest.raises(NotImplementedError): nx.array_equal(M, M) @@ -592,6 +594,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("matrix square root") + A = nx.kl_div(nx.abs(Mb), nx.abs(Mb) + 1) + lst_b.append(nx.to_numpy(A)) + lst_name.append("Kullback-Leibler divergence") + A = nx.concatenate([vb, nx.from_numpy(np.array([np.inf, np.nan]))], axis=0) A = nx.isfinite(A) lst_b.append(nx.to_numpy(A)) diff --git a/test/test_partial.py b/test/test_partial.py index 33fc259..ae4a1ab 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -79,6 +79,8 @@ def test_partial_wasserstein_lagrange(): w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 1, log=True) + w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 100, log=True) + def test_partial_wasserstein(): diff --git a/test/test_solvers.py b/test/test_solvers.py new file mode 100644 index 0000000..b792aca --- /dev/null +++ b/test/test_solvers.py @@ -0,0 +1,133 @@ +"""Tests for ot solvers""" + +# Author: Remi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + + +import itertools +import numpy as np +import pytest + +import ot + + +lst_reg = [None, 1.0] +lst_reg_type = ['KL', 'entropy', 'L2'] +lst_unbalanced = [None, 0.9] +lst_unbalanced_type = ['KL', 'L2', 'TV'] + + +def assert_allclose_sol(sol1, sol2): + + lst_attr = ['value', 'value_linear', 'plan', + 'potential_a', 'potential_b', 'marginal_a', 'marginal_b'] + + nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() + nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() + + for attr in lst_attr: + try: + np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr))) + except NotImplementedError: + pass + + +def test_solve(nx): + n_samples_s = 10 + n_samples_t = 7 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples_s, n_features) + y = rng.randn(n_samples_t, n_features) + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + + M = ot.dist(x, y) + + # solve unif weights + sol0 = ot.solve(M) + + print(sol0) + + # solve signe weights + sol = ot.solve(M, a, b) + + # check some attributes + sol.potentials + sol.sparse_plan + sol.marginals + sol.status + + assert_allclose_sol(sol0, sol) + + # solve in backend + ab, bb, Mb = nx.from_numpy(a, b, M) + solb = ot.solve(M, a, b) + + assert_allclose_sol(sol, solb) + + # test not implemented unbalanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve(M, unbalanced=1, unbalanced_type='cryptic divergence') + + # test not implemented reg_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve(M, reg=1, reg_type='cryptic divergence') + + +@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) +def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type): + n_samples_s = 10 + n_samples_t = 7 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples_s, n_features) + y = rng.randn(n_samples_t, n_features) + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + + M = ot.dist(x, y) + + try: + + # solve unif weights + sol0 = ot.solve(M, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) + + # solve signe weights + sol = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) + + assert_allclose_sol(sol0, sol) + + # solve in backend + ab, bb, Mb = nx.from_numpy(a, b, M) + solb = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) + + assert_allclose_sol(sol, solb) + except NotImplementedError: + pass + + +def test_solve_not_implemented(nx): + + n_samples_s = 10 + n_samples_t = 7 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples_s, n_features) + y = rng.randn(n_samples_t, n_features) + + M = ot.dist(x, y) + + # test not implemented and check raise + with pytest.raises(NotImplementedError): + ot.solve(M, reg=1.0, reg_type='cryptic divergence') + with pytest.raises(NotImplementedError): + ot.solve(M, unbalanced=1.0, unbalanced_type='cryptic divergence') + + # pairs of incompatible divergences + with pytest.raises(NotImplementedError): + ot.solve(M, reg=1.0, reg_type='kl', unbalanced=1.0, unbalanced_type='tv') diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index fc40df0..b76d738 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -5,6 +5,7 @@ # # License: MIT License +import itertools import numpy as np import ot import pytest @@ -289,6 +290,28 @@ def test_implemented_methods(nx): method=method) +@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'])) +def test_lbfgsb_unbalanced(nx, reg_div, regm_div): + + np.random.seed(42) + + xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) + + M = ot.dist(xs, xt) + + a = ot.unif(5) + b = ot.unif(6) + + G, log = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False) + + ab, bb, Mb = nx.from_numpy(a, b, M) + + Gb, log = ot.unbalanced.lbfgsb_unbalanced(ab, bb, Mb, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False) + + np.testing.assert_allclose(G, nx.to_numpy(Gb)) + + def test_mm_convergence(nx): n = 100 rng = np.random.RandomState(42) diff --git a/test/test_utils.py b/test/test_utils.py index 19b6365..666c157 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -301,3 +301,32 @@ def test_BaseEstimator(): cl.set_params(bibi=10) assert cl.first == 'spam again' + + +def test_OTResult(): + + res = ot.utils.OTResult() + + # test print + print(res) + + # tets get citation + print(res.citation) + + lst_attributes = ['a_to_b', + 'b_to_a', + 'lazy_plan', + 'marginal_a', + 'marginal_b', + 'marginals', + 'plan', + 'potential_a', + 'potential_b', + 'potentials', + 'sparse_plan', + 'status', + 'value', + 'value_linear'] + for at in lst_attributes: + with pytest.raises(NotImplementedError): + getattr(res, at) |