Diffstat (limited to 'test/test_ot.py')
-rw-r--r--  test/test_ot.py  183
1 file changed, 88 insertions(+), 95 deletions(-)
diff --git a/test/test_ot.py b/test/test_ot.py
index b7306f6..92f26a7 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -8,13 +8,13 @@ import warnings
import numpy as np
import pytest
-from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
+from ot.backend import torch
-def test_emd_dimension_mismatch():
+def test_emd_dimension_and_mass_mismatch():
# test emd and emd2 for dimension mismatch and mass mismatch
n_samples = 100
n_features = 2
@@ -29,122 +29,125 @@ def test_emd_dimension_mismatch():
np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
+ b = a.copy()
+ a[0] = 100
+ np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
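
Editorial note: the new assertion relies on ot.emd rejecting marginals with unequal total mass. A hedged sketch of such a guard (the helper name and error message are illustrative, not POT's exact source):

    def check_equal_mass(a, b):
        # raise if the two marginals do not carry the same total mass
        np.testing.assert_almost_equal(
            a.sum(0), b.sum(0),
            err_msg='a and b vector must have the same sum')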
-def test_emd_emd2():
- # test emd and emd2 for simple identity
- n = 100
+
+def test_emd_backends(nx):
+ n_samples = 100
+ n_features = 2
rng = np.random.RandomState(0)
- x = rng.randn(n, 2)
- u = ot.utils.unif(n)
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
- M = ot.dist(x, x)
+ M = ot.dist(x, y)
- G = ot.emd(u, u, M)
+ G = ot.emd(a, a, M)
- # check G is identity
- np.testing.assert_allclose(G, np.eye(n) / n)
- # check constraints
- np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
- np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
- w = ot.emd2(u, u, M)
- # check loss=0
- np.testing.assert_allclose(w, 0)
+ Gb = ot.emd(ab, ab, Mb)
+
+ np.testing.assert_allclose(G, nx.to_numpy(Gb))
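
Editorial note: nx is the backend fixture supplied by POT's test conftest. A minimal sketch of the round-trip pattern exercised above, assuming the NumPy backend from ot.backend:

    from ot.backend import NumpyBackend

    nx = NumpyBackend()
    ab, Mb = nx.from_numpy(a), nx.from_numpy(M)  # convert to backend arrays
    Gb = ot.emd(ab, ab, Mb)                      # solver dispatches on array type
    np.testing.assert_allclose(G, nx.to_numpy(Gb))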
-def test_emd_1d_emd2_1d():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
+def test_emd2_backends(nx):
+ n_samples = 100
+ n_features = 2
rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
- M = ot.dist(u, v, metric='sqeuclidean')
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
- G, log = ot.emd([], [], M, log=True)
- wass = log["cost"]
- G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
- wass1d = log["cost"]
- wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
- wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
+ val = ot.emd2(a, a, M)
- # check loss is similar
- np.testing.assert_allclose(wass, wass1d)
- np.testing.assert_allclose(wass, wass1d_emd2)
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
- # check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
- np.testing.assert_allclose(wass_sp, wass1d_euc)
+ valb = ot.emd2(ab, ab, Mb)
- # check constraints
- np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
- np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
+ np.testing.assert_allclose(val, nx.to_numpy(valb))
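
Editorial note: emd2 returns the transport cost of the optimal plan, i.e. sum(G * M); a hedged cross-check tying this test to test_emd_backends above:

    G = ot.emd(a, a, M)                              # optimal plan
    np.testing.assert_allclose(val, np.sum(G * M))   # cost identity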
+
+
+def test_emd_emd2_types_devices(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ ab = nx.from_numpy(a, type_as=tp)
+ Mb = nx.from_numpy(M, type_as=tp)
- # check G is similar
- np.testing.assert_allclose(G, G_1d)
+ Gb = ot.emd(ab, ab, Mb)
- # check AssertionError is raised if called on non 1d arrays
- u = np.random.randn(n, 2)
- v = np.random.randn(m, 2)
- with pytest.raises(AssertionError):
- ot.emd_1d(u, v, [], [])
+ w = ot.emd2(ab, ab, Mb)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, w)
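
Editorial note: assert_same_dtype_device is a helper on the backend test classes; roughly, it compares the (dtype, device) pairs reported by dtype_device, as in this hedged sketch:

    def same_dtype_device(nx, x, y):
        # compare the dtype and device reported by the backend
        x_dtype, x_device = nx.dtype_device(x)
        y_dtype, y_device = nx.dtype_device(y)
        return x_dtype == y_dtype and str(x_device) == str(y_device)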
-def test_emd_1d_emd2_1d_with_weights():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
+
+def test_emd2_gradients():
+ n_samples = 100
+ n_features = 2
rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
- w_u = rng.uniform(0., 1., n)
- w_u = w_u / w_u.sum()
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
- w_v = rng.uniform(0., 1., m)
- w_v = w_v / w_v.sum()
+ M = ot.dist(x, y)
- M = ot.dist(u, v, metric='sqeuclidean')
+ if torch:
- G, log = ot.emd(w_u, w_v, M, log=True)
- wass = log["cost"]
- G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
- wass1d = log["cost"]
- wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
- wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
+ a1 = torch.tensor(a, requires_grad=True)
+ b1 = torch.tensor(a, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
- # check loss is similar
- np.testing.assert_allclose(wass, wass1d)
- np.testing.assert_allclose(wass, wass1d_emd2)
+ val = ot.emd2(a1, b1, M1)
- # check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
- np.testing.assert_allclose(wass_sp, wass1d_euc)
+ val.backward()
- # check constraints
- np.testing.assert_allclose(w_u, G.sum(1))
- np.testing.assert_allclose(w_v, G.sum(0))
+ assert a1.shape == a1.grad.shape
+ assert b1.shape == b1.grad.shape
+ assert M1.shape == M1.grad.shape
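
Editorial note: by the envelope theorem the exact EMD value is linear in M at the optimum, so the gradient of ot.emd2 with respect to M should equal the optimal plan; a hedged follow-up check under the same torch setup:

    G = ot.emd(a, a, M)                             # optimal plan (NumPy)
    np.testing.assert_allclose(M1.grad.numpy(), G)  # dval/dM == plan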
-def test_wass_1d():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
+def test_emd_emd2():
+ # test emd and emd2 for simple identity
+ n = 100
rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
- M = ot.dist(u, v, metric='sqeuclidean')
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
- G, log = ot.emd([], [], M, log=True)
- wass = log["cost"]
+ M = ot.dist(x, x)
- wass1d = ot.wasserstein_1d(u, v, [], [], p=2.)
+ G = ot.emd(u, u, M)
+
+ # check G is identity
+ np.testing.assert_allclose(G, np.eye(n) / n)
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
+ np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
- # check loss is similar
- np.testing.assert_allclose(np.sqrt(wass), wass1d)
+ w = ot.emd2(u, u, M)
+ # check loss=0
+ np.testing.assert_allclose(w, 0)
def test_emd_empty():
@@ -291,17 +294,7 @@ def test_warnings():
print('Computing {} EMD '.format(1))
ot.emd(a, b, M, numItermax=1)
assert "numItermax" in str(w[-1].message)
- assert len(w) == 1
- a[0] = 100
- print('Computing {} EMD '.format(2))
- ot.emd(a, b, M)
- assert "infeasible" in str(w[-1].message)
- assert len(w) == 2
- a[0] = -1
- print('Computing {} EMD '.format(2))
- ot.emd(a, b, M)
- assert "infeasible" in str(w[-1].message)
- assert len(w) == 3
+ # assert len(w) == 1
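
Editorial note: the surviving assertion runs inside a warnings capture block (per the import warnings context above); the pattern, sketched:

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        ot.emd(a, b, M, numItermax=1)  # too few iterations to converge
        assert "numItermax" in str(w[-1].message)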
def test_dual_variables():