summary refs log tree commit diff
path: root/test/test_ot.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- test/test_ot.py 112
1 files changed, 104 insertions, 8 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 7652394..47df946 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -7,20 +7,27 @@
import warnings
import numpy as np
+from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
import pytest
-def test_doctest():
- import doctest
+def test_emd_dimension_mismatch():
+ # test emd and emd2 for dimension mismatch
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples + 1)
- # test lp solver
- doctest.testmod(ot.lp, verbose=True)
+ M = ot.dist(x, x)
- # test bregman solver
- doctest.testmod(ot.bregman, verbose=True)
+ np.testing.assert_raises(AssertionError, ot.emd, a, a, M)
+
+ np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
def test_emd_emd2():
@@ -37,7 +44,7 @@ def test_emd_emd2():
# check G is identity
np.testing.assert_allclose(G, np.eye(n) / n)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
@@ -46,6 +53,64 @@ def test_emd_emd2():
np.testing.assert_allclose(w, 0)
+def test_emd_1d_emd2_1d():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd([], [], M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, )))
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1))
+ np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0))
+
+ # check G is similar
+ np.testing.assert_allclose(G, G_1d)
+
+ # check AssertionError is raised if called on non 1d arrays
+ u = np.random.randn(n, 2)
+ v = np.random.randn(m, 2)
+ with pytest.raises(AssertionError):
+ ot.emd_1d(u, v, [], [])
+
+
+def test_wass_1d():
+ # test wasserstein_1d (p=2) matches the square root of the emd cost
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd([], [], M, log=True)
+ wass = log["cost"]
+
+ wass1d = ot.wasserstein_1d(u, v, [], [], p=2.)
+
+ # check loss is similar
+ np.testing.assert_allclose(np.sqrt(wass), wass1d)
+
+
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100
@@ -60,7 +125,7 @@ def test_emd_empty():
# check G is identity
np.testing.assert_allclose(G, np.eye(n) / n)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
@@ -69,6 +134,28 @@ def test_emd_empty():
np.testing.assert_allclose(w, 0)
+def test_emd_sparse():
+
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ x2 = rng.randn(n, 2)
+
+ M = ot.dist(x, x2)
+
+ G = ot.emd([], [], M, dense=True)
+
+ Gs = ot.emd([], [], M, dense=False)
+
+ ws = ot.emd2([], [], M, dense=False)
+
+ # check G is the same
+ np.testing.assert_allclose(G, Gs.todense())
+ # check value
+ np.testing.assert_allclose(Gs.multiply(M).sum(), ws, rtol=1e-6)
+
+
def test_emd2_multi():
n = 500 # nb bins
@@ -100,7 +187,12 @@ def test_emd2_multi():
emdn = ot.emd2(a, b, M)
ot.toc('multi proc : {} s')
+ ot.tic()
+ emdn2 = ot.emd2(a, b, M, dense=False)
+ ot.toc('multi proc : {} s')
+
np.testing.assert_allclose(emd1, emdn)
+ np.testing.assert_allclose(emd1, emdn2, rtol=1e-6)
# emd loss multipro proc with log
ot.tic()
@@ -246,6 +338,10 @@ def test_dual_variables():
np.testing.assert_almost_equal(cost1, log['cost'])
check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost'])
+ constraint_violation = log['u'][:, None] + log['v'][None, :] - M
+
+ assert constraint_violation.max() < 1e-8
+
def check_duality_gap(a, b, M, G, u, v, cost):
cost_dual = np.vdot(a, u) + np.vdot(b, v)