From 18502d6861a4977cbade957f2e48eeb8dbb55414 Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Fri, 21 Jun 2019 11:21:08 +0200 Subject: Sparse G matrix for EMD1d + standard metrics computed without cdist --- test/test_ot.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) (limited to 'test/test_ot.py') diff --git a/test/test_ot.py b/test/test_ot.py index 7008002..2a2e0a5 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -7,6 +7,7 @@ import warnings import numpy as np +from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss @@ -37,7 +38,7 @@ def test_emd_emd2(): # check G is identity np.testing.assert_allclose(G, np.eye(n) / n) - # check constratints + # check constraints np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn @@ -46,12 +47,13 @@ def test_emd_emd2(): np.testing.assert_allclose(w, 0) -def test_emd1d(): +def test_emd_1d_emd2_1d(): # test emd1d gives similar results as emd n = 20 m = 30 - u = np.random.randn(n, 1) - v = np.random.randn(m, 1) + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) M = ot.dist(u, v, metric='sqeuclidean') @@ -59,9 +61,20 @@ def test_emd1d(): wass = log["cost"] G_1d, log = ot.emd_1d([], [], u, v, metric='sqeuclidean', log=True) wass1d = log["cost"] + wass1d_emd2 = ot.emd2_1d([], [], u, v, metric='sqeuclidean', log=False) + wass1d_euc = ot.emd2_1d([], [], u, v, metric='euclidean', log=False) # check loss is similar np.testing.assert_allclose(wass, wass1d) + np.testing.assert_allclose(wass, wass1d_emd2) + + # check loss is similar to scipy's implementation for Euclidean metric + wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, ))) + np.testing.assert_allclose(wass_sp, wass1d_euc) + + # check constraints + np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1)) + np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0)) # check G is similar np.testing.assert_allclose(G, G_1d) @@ -86,7 +99,7 @@ def test_emd_empty(): # check G is identity np.testing.assert_allclose(G, np.eye(n) / n) - # check constratints + # check constraints np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn -- cgit v1.2.3