author    Rémi Flamary <remi.flamary@gmail.com>    2021-11-04 15:19:57 +0100
committer GitHub <noreply@github.com>    2021-11-04 15:19:57 +0100
commit 0e431c203a66c6d48e6bb1efeda149460472a0f0 (patch)
tree   22a447a1dbb1505b18f9e426e1761cf6b328b6eb /test
parent 2fe69eb130827560ada704bc25998397c4357821 (diff)
[MRG] Add tests about type and GPU for emd/emd2 + 1d variants + wasserstein1d (#304)
* new test gpu
* pep 8 of course
* debug torch
* jax with gpu
* device put
* device put
* it works
* emd1d and emd2_1d working
* emd_1d and emd2_1d done
* cleanup
* of course
* should work on gpu now
* tests done + pep8
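The new tests all follow one pattern: build numpy inputs once, then loop over every dtype/device entry in the backend's __type_list__, convert the inputs with from_numpy(..., type_as=tp), run the solver, and check that the output keeps the input's dtype. A minimal standalone sketch of that pattern, assuming POT's ot.backend.get_backend_list and the top-level ot.wasserstein_1d (neither appears in this diff itself):

    import numpy as np

    import ot
    from ot.backend import get_backend_list

    rng = np.random.RandomState(0)
    x = np.linspace(0, 5, 10)
    a = np.abs(rng.randn(10))
    a /= a.sum()  # normalized weights

    for nx in get_backend_list():
        for tp in nx.__type_list__:
            # move inputs onto the backend, matching the dtype (and device) of tp
            xb = nx.from_numpy(x, type_as=tp)
            ab = nx.from_numpy(a, type_as=tp)
            res = ot.wasserstein_1d(xb, xb, ab, ab, p=1)
            # the solver should not silently change the input dtype
            if str(nx) != 'numpy':
                assert res.dtype == xb.dtype

On a machine with a GPU, the torch and jax backends are expected to list both CPU and GPU tensor types in __type_list__, so the same loop also exercises device handling.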
Diffstat (limited to 'test')
-rw-r--r--  test/test_1d_solver.py | 93
-rw-r--r--  test/test_ot.py        | 67
2 files changed, 120 insertions(+), 40 deletions(-)
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 2c470c2..77b1234 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -83,3 +83,96 @@ def test_wasserstein_1d(nx):
Xb = nx.from_numpy(X)
res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2)
np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4)
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_wasserstein_1d_type_devices(nx):
+
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ for tp in nx.__type_list__:
+
+ print(tp.dtype)
+
+ xb = nx.from_numpy(x, type_as=tp)
+ rho_ub = nx.from_numpy(rho_u, type_as=tp)
+ rho_vb = nx.from_numpy(rho_v, type_as=tp)
+
+ res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+
+ if not str(nx) == 'numpy':
+ assert res.dtype == xb.dtype
+
+
+def test_emd_1d_emd2_1d():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd([], [], M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
+ np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
+
+ # check G is similar
+ np.testing.assert_allclose(G, G_1d, atol=1e-15)
+
+ # check AssertionError is raised if called on non 1d arrays
+ u = np.random.randn(n, 2)
+ v = np.random.randn(m, 2)
+ with pytest.raises(AssertionError):
+ ot.emd_1d(u, v, [], [])
+
+
+def test_emd1d_type_devices(nx):
+
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ for tp in nx.__type_list__:
+
+ print(tp.dtype)
+
+ xb = nx.from_numpy(x, type_as=tp)
+ rho_ub = nx.from_numpy(rho_u, type_as=tp)
+ rho_vb = nx.from_numpy(rho_v, type_as=tp)
+
+ emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+
+ emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+
+ assert emd.dtype == xb.dtype
+ if not str(nx) == 'numpy':
+ assert emd2.dtype == xb.dtype
diff --git a/test/test_ot.py b/test/test_ot.py
index 5bfde1d..dc3930a 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -12,7 +12,6 @@ import pytest
import ot
from ot.datasets import make_1D_gauss as gauss
from ot.backend import torch
-from scipy.stats import wasserstein_distance
def test_emd_dimension_and_mass_mismatch():
@@ -77,6 +76,33 @@ def test_emd2_backends(nx):
np.allclose(val, nx.to_numpy(valb))
+def test_emd_emd2_types_devices(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ for tp in nx.__type_list__:
+
+ print(tp.dtype)
+
+ ab = nx.from_numpy(a, type_as=tp)
+ Mb = nx.from_numpy(M, type_as=tp)
+
+ Gb = ot.emd(ab, ab, Mb)
+
+ w = ot.emd2(ab, ab, Mb)
+
+ assert Gb.dtype == Mb.dtype
+ if not str(nx) == 'numpy':
+ assert w.dtype == Mb.dtype
+
+
def test_emd2_gradients():
n_samples = 100
n_features = 2
@@ -126,45 +152,6 @@ def test_emd_emd2():
np.testing.assert_allclose(w, 0)
-def test_emd_1d_emd2_1d():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
- rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
-
- M = ot.dist(u, v, metric='sqeuclidean')
-
- G, log = ot.emd([], [], M, log=True)
- wass = log["cost"]
- G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
- wass1d = log["cost"]
- wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
- wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
-
- # check loss is similar
- np.testing.assert_allclose(wass, wass1d)
- np.testing.assert_allclose(wass, wass1d_emd2)
-
- # check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
- np.testing.assert_allclose(wass_sp, wass1d_euc)
-
- # check constraints
- np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
- np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
-
- # check G is similar
- np.testing.assert_allclose(G, G_1d, atol=1e-15)
-
- # check AssertionError is raised if called on non 1d arrays
- u = np.random.randn(n, 2)
- v = np.random.randn(m, 2)
- with pytest.raises(AssertionError):
- ot.emd_1d(u, v, [], [])
-
-
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100