summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-21 11:21:08 +0200
committerRomain Tavenard <romain.tavenard@univ-rennes2.fr>2019-06-21 11:21:08 +0200
commit18502d6861a4977cbade957f2e48eeb8dbb55414 (patch)
tree947cf67b5c118ba6eafd72e38ccb0977085767ca /test
parentcada9a3019997e8efb95d96c86985110f1e937b9 (diff)
Sparse G matrix for EMD1d + standard metrics computed without cdist
Diffstat (limited to 'test')
-rw-r--r--test/test_ot.py23
1 files changed, 18 insertions, 5 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 7008002..2a2e0a5 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -7,6 +7,7 @@
import warnings
import numpy as np
+from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
@@ -37,7 +38,7 @@ def test_emd_emd2():
# check G is identity
np.testing.assert_allclose(G, np.eye(n) / n)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
@@ -46,12 +47,13 @@ def test_emd_emd2():
np.testing.assert_allclose(w, 0)
-def test_emd1d():
+def test_emd_1d_emd2_1d():
# test emd1d gives similar results as emd
n = 20
m = 30
- u = np.random.randn(n, 1)
- v = np.random.randn(m, 1)
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
M = ot.dist(u, v, metric='sqeuclidean')
@@ -59,9 +61,20 @@ def test_emd1d():
wass = log["cost"]
G_1d, log = ot.emd_1d([], [], u, v, metric='sqeuclidean', log=True)
wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d([], [], u, v, metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d([], [], u, v, metric='euclidean', log=False)
# check loss is similar
np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, )))
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1))
+ np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0))
# check G is similar
np.testing.assert_allclose(G, G_1d)
@@ -86,7 +99,7 @@ def test_emd_empty():
# check G is identity
np.testing.assert_allclose(G, np.eye(n) / n)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn