summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py57
1 files changed, 1 insertions, 56 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 4dfc510..5bfde1d 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -8,11 +8,11 @@ import warnings
import numpy as np
import pytest
-from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
from ot.backend import torch
+from scipy.stats import wasserstein_distance
def test_emd_dimension_and_mass_mismatch():
@@ -165,61 +165,6 @@ def test_emd_1d_emd2_1d():
ot.emd_1d(u, v, [], [])
-def test_emd_1d_emd2_1d_with_weights():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
- rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
-
- w_u = rng.uniform(0., 1., n)
- w_u = w_u / w_u.sum()
-
- w_v = rng.uniform(0., 1., m)
- w_v = w_v / w_v.sum()
-
- M = ot.dist(u, v, metric='sqeuclidean')
-
- G, log = ot.emd(w_u, w_v, M, log=True)
- wass = log["cost"]
- G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
- wass1d = log["cost"]
- wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
- wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
-
- # check loss is similar
- np.testing.assert_allclose(wass, wass1d)
- np.testing.assert_allclose(wass, wass1d_emd2)
-
- # check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
- np.testing.assert_allclose(wass_sp, wass1d_euc)
-
- # check constraints
- np.testing.assert_allclose(w_u, G.sum(1))
- np.testing.assert_allclose(w_v, G.sum(0))
-
-
-def test_wass_1d():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
- rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
-
- M = ot.dist(u, v, metric='sqeuclidean')
-
- G, log = ot.emd([], [], M, log=True)
- wass = log["cost"]
-
- wass1d = ot.wasserstein_1d(u, v, [], [], p=2.)
-
- # check loss is similar
- np.testing.assert_allclose(np.sqrt(wass), wass1d)
-
-
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100