From 0e431c203a66c6d48e6bb1efeda149460472a0f0 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 4 Nov 2021 15:19:57 +0100 Subject: [MRG] Add tests about type and GPU for emd/emd2 + 1d variants + wasserstein1d (#304) * new test gpu * pep 8 of couse * debug torch * jax with gpu * device put * device put * it works * emd1d and emd2_1d working * emd_1d and emd2_1d done * cleanup * of course * should work on gpu now * tests done+ pep8 --- test/test_1d_solver.py | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++ test/test_ot.py | 67 +++++++++++++++--------------------- 2 files changed, 120 insertions(+), 40 deletions(-) (limited to 'test') diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 2c470c2..77b1234 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -83,3 +83,96 @@ def test_wasserstein_1d(nx): Xb = nx.from_numpy(X) res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2) np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) + + +@pytest.mark.parametrize('nx', backend_list) +def test_wasserstein_1d_type_devices(nx): + + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + + print(tp.dtype) + + xb = nx.from_numpy(x, type_as=tp) + rho_ub = nx.from_numpy(rho_u, type_as=tp) + rho_vb = nx.from_numpy(rho_v, type_as=tp) + + res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) + + if not str(nx) == 'numpy': + assert res.dtype == xb.dtype + + +def test_emd_1d_emd2_1d(): + # test emd1d gives similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) + + M = ot.dist(u, v, metric='sqeuclidean') + + G, log = ot.emd([], [], M, log=True) + wass = log["cost"] + G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True) + wass1d = log["cost"] + wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False) + 
wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False) + + # check loss is similar + np.testing.assert_allclose(wass, wass1d) + np.testing.assert_allclose(wass, wass1d_emd2) + + # check loss is similar to scipy's implementation for Euclidean metric + wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,))) + np.testing.assert_allclose(wass_sp, wass1d_euc) + + # check constraints + np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1)) + np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0)) + + # check G is similar + np.testing.assert_allclose(G, G_1d, atol=1e-15) + + # check AssertionError is raised if called on non 1d arrays + u = np.random.randn(n, 2) + v = np.random.randn(m, 2) + with pytest.raises(AssertionError): + ot.emd_1d(u, v, [], []) + + +def test_emd1d_type_devices(nx): + + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + + print(tp.dtype) + + xb = nx.from_numpy(x, type_as=tp) + rho_ub = nx.from_numpy(rho_u, type_as=tp) + rho_vb = nx.from_numpy(rho_v, type_as=tp) + + emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) + + emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) + + assert emd.dtype == xb.dtype + if not str(nx) == 'numpy': + assert emd2.dtype == xb.dtype diff --git a/test/test_ot.py b/test/test_ot.py index 5bfde1d..dc3930a 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,7 +12,6 @@ import pytest import ot from ot.datasets import make_1D_gauss as gauss from ot.backend import torch -from scipy.stats import wasserstein_distance def test_emd_dimension_and_mass_mismatch(): @@ -77,6 +76,33 @@ def test_emd2_backends(nx): np.allclose(val, nx.to_numpy(valb)) +def test_emd_emd2_types_devices(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = 
ot.dist(x, y) + + for tp in nx.__type_list__: + + print(tp.dtype) + + ab = nx.from_numpy(a, type_as=tp) + Mb = nx.from_numpy(M, type_as=tp) + + Gb = ot.emd(ab, ab, Mb) + + w = ot.emd2(ab, ab, Mb) + + assert Gb.dtype == Mb.dtype + if not str(nx) == 'numpy': + assert w.dtype == Mb.dtype + + def test_emd2_gradients(): n_samples = 100 n_features = 2 @@ -126,45 +152,6 @@ def test_emd_emd2(): np.testing.assert_allclose(w, 0) -def test_emd_1d_emd2_1d(): - # test emd1d gives similar results as emd - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.randn(n, 1) - v = rng.randn(m, 1) - - M = ot.dist(u, v, metric='sqeuclidean') - - G, log = ot.emd([], [], M, log=True) - wass = log["cost"] - G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True) - wass1d = log["cost"] - wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False) - wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False) - - # check loss is similar - np.testing.assert_allclose(wass, wass1d) - np.testing.assert_allclose(wass, wass1d_emd2) - - # check loss is similar to scipy's implementation for Euclidean metric - wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,))) - np.testing.assert_allclose(wass_sp, wass1d_euc) - - # check constraints - np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1)) - np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0)) - - # check G is similar - np.testing.assert_allclose(G, G_1d, atol=1e-15) - - # check AssertionError is raised if called on non 1d arrays - u = np.random.randn(n, 2) - v = np.random.randn(m, 2) - with pytest.raises(AssertionError): - ot.emd_1d(u, v, [], []) - - def test_emd_empty(): # test emd and emd2 for simple identity n = 100 -- cgit v1.2.3