summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
authorAdrienCorenflos <adrien.corenflos@gmail.com>2020-04-02 10:39:55 +0100
committerAdrienCorenflos <adrien.corenflos@gmail.com>2020-04-02 10:39:55 +0100
commit1e2e118e3a30224932ed2f012bb8f9f0f374ef2c (patch)
tree5c157d143025379272294db73e821e7a69575569 /test/test_ot.py
parent592f933085d5b521a440eb91eccc283c43732170 (diff)
Fix test
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py18
1 files changed, 6 insertions, 12 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 7afdae3..0f1357f 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -7,11 +7,11 @@
import warnings
import numpy as np
+import pytest
from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
-import pytest
def test_emd_dimension_mismatch():
@@ -75,12 +75,12 @@ def test_emd_1d_emd2_1d():
np.testing.assert_allclose(wass, wass1d_emd2)
# check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, )))
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
np.testing.assert_allclose(wass_sp, wass1d_euc)
# check constraints
- np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1))
- np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0))
+ np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
+ np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
# check G is similar
np.testing.assert_allclose(G, G_1d)
@@ -91,8 +91,8 @@ def test_emd_1d_emd2_1d():
with pytest.raises(AssertionError):
ot.emd_1d(u, v, [], [])
-def test_emd_1d_emd2_1d_with_weights():
+def test_emd_1d_emd2_1d_with_weights():
# test emd1d gives similar results as emd
n = 20
m = 30
@@ -120,7 +120,7 @@ def test_emd_1d_emd2_1d_with_weights():
np.testing.assert_allclose(wass, wass1d_emd2)
# check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
np.testing.assert_allclose(wass_sp, wass1d_euc)
# check constraints
@@ -128,8 +128,6 @@ def test_emd_1d_emd2_1d_with_weights():
np.testing.assert_allclose(w_v, G.sum(0))
-
-
def test_wass_1d():
# test emd1d gives similar results as emd
n = 20
@@ -173,7 +171,6 @@ def test_emd_empty():
def test_emd_sparse():
-
n = 100
rng = np.random.RandomState(0)
@@ -249,7 +246,6 @@ def test_emd2_multi():
def test_lp_barycenter():
-
a1 = np.array([1.0, 0, 0])[:, None]
a2 = np.array([0, 0, 1.0])[:, None]
@@ -266,7 +262,6 @@ def test_lp_barycenter():
def test_free_support_barycenter():
-
measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
measures_weights = [np.array([1.]), np.array([1.])]
@@ -282,7 +277,6 @@ def test_free_support_barycenter():
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():
-
a1 = np.array([1.0, 0, 0])[:, None]
a2 = np.array([0, 0, 1.0])[:, None]