diff options
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- | test/test_ot.py | 183 |
1 files changed, 88 insertions, 95 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index b7306f6..92f26a7 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -8,13 +8,13 @@ import warnings import numpy as np import pytest -from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss +from ot.backend import torch -def test_emd_dimension_mismatch(): +def test_emd_dimension_and_mass_mismatch(): # test emd and emd2 for dimension mismatch n_samples = 100 n_features = 2 @@ -29,122 +29,125 @@ def test_emd_dimension_mismatch(): np.testing.assert_raises(AssertionError, ot.emd2, a, a, M) + b = a.copy() + a[0] = 100 + np.testing.assert_raises(AssertionError, ot.emd, a, b, M) -def test_emd_emd2(): - # test emd and emd2 for simple identity - n = 100 + +def test_emd_backends(nx): + n_samples = 100 + n_features = 2 rng = np.random.RandomState(0) - x = rng.randn(n, 2) - u = ot.utils.unif(n) + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) - M = ot.dist(x, x) + M = ot.dist(x, y) - G = ot.emd(u, u, M) + G = ot.emd(a, a, M) - # check G is identity - np.testing.assert_allclose(G, np.eye(n) / n) - # check constraints - np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn - np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) - w = ot.emd2(u, u, M) - # check loss=0 - np.testing.assert_allclose(w, 0) + Gb = ot.emd(ab, ab, Mb) + + np.allclose(G, nx.to_numpy(Gb)) -def test_emd_1d_emd2_1d(): - # test emd1d gives similar results as emd - n = 20 - m = 30 +def test_emd2_backends(nx): + n_samples = 100 + n_features = 2 rng = np.random.RandomState(0) - u = rng.randn(n, 1) - v = rng.randn(m, 1) - M = ot.dist(u, v, metric='sqeuclidean') + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) - G, log = ot.emd([], [], M, log=True) - wass = log["cost"] - G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True) - wass1d = log["cost"] - wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False) - wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False) + val = ot.emd2(a, a, M) - # check loss is similar - np.testing.assert_allclose(wass, wass1d) - np.testing.assert_allclose(wass, wass1d_emd2) + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) - # check loss is similar to scipy's implementation for Euclidean metric - wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,))) - np.testing.assert_allclose(wass_sp, wass1d_euc) + valb = ot.emd2(ab, ab, Mb) - # check constraints - np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1)) - np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0)) + np.allclose(val, nx.to_numpy(valb)) + + +def test_emd_emd2_types_devices(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + ab = nx.from_numpy(a, type_as=tp) + Mb = nx.from_numpy(M, type_as=tp) - # check G is similar - np.testing.assert_allclose(G, G_1d) + Gb = ot.emd(ab, ab, Mb) - # check AssertionError is raised if called on non 1d arrays - u = np.random.randn(n, 2) - v = np.random.randn(m, 2) - with pytest.raises(AssertionError): - ot.emd_1d(u, v, [], []) + w = ot.emd2(ab, ab, Mb) + nx.assert_same_dtype_device(Mb, Gb) + nx.assert_same_dtype_device(Mb, w) -def test_emd_1d_emd2_1d_with_weights(): - # test emd1d gives similar results as emd - n = 20 - m = 30 + +def test_emd2_gradients(): + n_samples = 100 + n_features = 2 rng = np.random.RandomState(0) - u = rng.randn(n, 1) - v = rng.randn(m, 1) - w_u = rng.uniform(0., 1., n) - w_u = w_u / w_u.sum() + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) - w_v = rng.uniform(0., 1., m) - w_v = w_v / w_v.sum() + M = ot.dist(x, y) - M = ot.dist(u, v, metric='sqeuclidean') + if torch: - G, log = ot.emd(w_u, w_v, M, log=True) - wass = log["cost"] - G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True) - wass1d = log["cost"] - wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False) - wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False) + a1 = torch.tensor(a, requires_grad=True) + b1 = torch.tensor(a, requires_grad=True) + M1 = torch.tensor(M, requires_grad=True) - # check loss is similar - np.testing.assert_allclose(wass, wass1d) - np.testing.assert_allclose(wass, wass1d_emd2) + val = ot.emd2(a1, b1, M1) - # check loss is similar to scipy's implementation for Euclidean metric - wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v) - np.testing.assert_allclose(wass_sp, wass1d_euc) + val.backward() - # check constraints - np.testing.assert_allclose(w_u, G.sum(1)) - np.testing.assert_allclose(w_v, G.sum(0)) + assert a1.shape == a1.grad.shape + assert b1.shape == b1.grad.shape + assert M1.shape == M1.grad.shape -def test_wass_1d(): - # test emd1d gives similar results as emd - n = 20 - m = 30 +def test_emd_emd2(): + # test emd and emd2 for simple identity + n = 100 rng = np.random.RandomState(0) - u = rng.randn(n, 1) - v = rng.randn(m, 1) - M = ot.dist(u, v, metric='sqeuclidean') + x = rng.randn(n, 2) + u = ot.utils.unif(n) - G, log = ot.emd([], [], M, log=True) - wass = log["cost"] + M = ot.dist(x, x) - wass1d = ot.wasserstein_1d(u, v, [], [], p=2.) + G = ot.emd(u, u, M) + + # check G is identity + np.testing.assert_allclose(G, np.eye(n) / n) + # check constraints + np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn + np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn - # check loss is similar - np.testing.assert_allclose(np.sqrt(wass), wass1d) + w = ot.emd2(u, u, M) + # check loss=0 + np.testing.assert_allclose(w, 0) def test_emd_empty(): @@ -291,17 +294,7 @@ def test_warnings(): print('Computing {} EMD '.format(1)) ot.emd(a, b, M, numItermax=1) assert "numItermax" in str(w[-1].message) - assert len(w) == 1 - a[0] = 100 - print('Computing {} EMD '.format(2)) - ot.emd(a, b, M) - assert "infeasible" in str(w[-1].message) - assert len(w) == 2 - a[0] = -1 - print('Computing {} EMD '.format(2)) - ot.emd(a, b, M) - assert "infeasible" in str(w[-1].message) - assert len(w) == 3 + #assert len(w) == 1 def test_dual_variables(): |