diff options
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- | test/test_ot.py | 91 |
1 files changed, 79 insertions, 12 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index f45e4c9..3e953dc 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,9 +12,12 @@ from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss +from ot.backend import get_backend_list, torch +backend_list = get_backend_list() -def test_emd_dimension_mismatch(): + +def test_emd_dimension_and_mass_mismatch(): # test emd and emd2 for dimension mismatch n_samples = 100 n_features = 2 @@ -29,6 +32,80 @@ def test_emd_dimension_mismatch(): np.testing.assert_raises(AssertionError, ot.emd2, a, a, M) + b = a.copy() + a[0] = 100 + np.testing.assert_raises(AssertionError, ot.emd, a, b, M) + + +@pytest.mark.parametrize('nx', backend_list) +def test_emd_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + G = ot.emd(a, a, M) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + Gb = ot.emd(ab, ab, Mb) + + np.allclose(G, nx.to_numpy(Gb)) + + +@pytest.mark.parametrize('nx', backend_list) +def test_emd2_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + val = ot.emd2(a, a, M) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + valb = ot.emd2(ab, ab, Mb) + + np.allclose(val, nx.to_numpy(valb)) + + +def test_emd2_gradients(): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + if torch: + + a1 = torch.tensor(a, requires_grad=True) + b1 = torch.tensor(a, requires_grad=True) + M1 = torch.tensor(M, requires_grad=True) + + val = ot.emd2(a1, b1, M1) + + val.backward() + + assert a1.shape == a1.grad.shape + assert b1.shape == b1.grad.shape + assert M1.shape == M1.grad.shape + def test_emd_emd2(): # test emd and emd2 for simple identity @@ -83,7 +160,7 @@ def test_emd_1d_emd2_1d(): np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0)) # check G is similar - np.testing.assert_allclose(G, G_1d) + np.testing.assert_allclose(G, G_1d, atol=1e-15) # check AssertionError is raised if called on non 1d arrays u = np.random.randn(n, 2) @@ -292,16 +369,6 @@ def test_warnings(): ot.emd(a, b, M, numItermax=1) assert "numItermax" in str(w[-1].message) #assert len(w) == 1 - a[0] = 100 - print('Computing {} EMD '.format(2)) - ot.emd(a, b, M) - assert "infeasible" in str(w[-1].message) - #assert len(w) == 2 - a[0] = -1 - print('Computing {} EMD '.format(2)) - ot.emd(a, b, M) - assert "infeasible" in str(w[-1].message) - #assert len(w) == 3 def test_dual_variables(): |