summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2019-06-27 14:08:21 +0200
committerGitHub <noreply@github.com>2019-06-27 14:08:21 +0200
commita9b8af146648ee2ae50baf46e69e6281f6b279e4 (patch)
treec97a7359bdf7de19d7a1cc325304852622a8f580 /test/test_ot.py
parent2364d56aad650d501753cc93a69ea1b8ddf28b0a (diff)
parent362a7f8fa20cf7ae6f2e36d7e47c7ca9f81d3c51 (diff)
Merge pull request #89 from rtavenar/master
[MRG] EMD and Wasserstein 1D
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py62
1 files changed, 60 insertions, 2 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 7652394..3c4ac11 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -7,6 +7,7 @@
import warnings
import numpy as np
+from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
@@ -37,7 +38,7 @@ def test_emd_emd2():
# check G is identity
np.testing.assert_allclose(G, np.eye(n) / n)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
@@ -46,6 +47,63 @@ def test_emd_emd2():
np.testing.assert_allclose(w, 0)
+def test_emd_1d_emd2_1d():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd([], [], M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, )))
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1))
+ np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0))
+
+ # check G is similar
+ np.testing.assert_allclose(G, G_1d)
+
+ # check AssertionError is raised if called on non 1d arrays
+ u = np.random.randn(n, 2)
+ v = np.random.randn(m, 2)
+ np.testing.assert_raises(AssertionError, ot.emd_1d, u, v, [], [])
+
+
+def test_wass_1d():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd([], [], M, log=True)
+ wass = log["cost"]
+
+ wass1d = ot.wasserstein_1d(u, v, [], [], p=2.)
+
+ # check loss is similar
+ np.testing.assert_allclose(np.sqrt(wass), wass1d)
+
+
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100
@@ -60,7 +118,7 @@ def test_emd_empty():
# check G is identity
np.testing.assert_allclose(G, np.eye(n) / n)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn