From 1e2e118e3a30224932ed2f012bb8f9f0f374ef2c Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Thu, 2 Apr 2020 10:39:55 +0100 Subject: Fix test --- test/test_ot.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) (limited to 'test/test_ot.py') diff --git a/test/test_ot.py b/test/test_ot.py index 7afdae3..0f1357f 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -7,11 +7,11 @@ import warnings import numpy as np +import pytest from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss -import pytest def test_emd_dimension_mismatch(): @@ -75,12 +75,12 @@ def test_emd_1d_emd2_1d(): np.testing.assert_allclose(wass, wass1d_emd2) # check loss is similar to scipy's implementation for Euclidean metric - wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, ))) + wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,))) np.testing.assert_allclose(wass_sp, wass1d_euc) # check constraints - np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1)) - np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0)) + np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1)) + np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0)) # check G is similar np.testing.assert_allclose(G, G_1d) @@ -91,8 +91,8 @@ def test_emd_1d_emd2_1d(): with pytest.raises(AssertionError): ot.emd_1d(u, v, [], []) -def test_emd_1d_emd2_1d_with_weights(): +def test_emd_1d_emd2_1d_with_weights(): # test emd1d gives similar results as emd n = 20 m = 30 @@ -120,7 +120,7 @@ def test_emd_1d_emd2_1d_with_weights(): np.testing.assert_allclose(wass, wass1d_emd2) # check loss is similar to scipy's implementation for Euclidean metric - wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,))) + wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v) np.testing.assert_allclose(wass_sp, wass1d_euc) # check constraints @@ -128,8 +128,6 @@ def test_emd_1d_emd2_1d_with_weights(): np.testing.assert_allclose(w_v, G.sum(0)) - - def test_wass_1d(): # test emd1d gives similar results as emd n = 20 @@ -173,7 +171,6 @@ def test_emd_empty(): def test_emd_sparse(): - n = 100 rng = np.random.RandomState(0) @@ -249,7 +246,6 @@ def test_emd2_multi(): def test_lp_barycenter(): - a1 = np.array([1.0, 0, 0])[:, None] a2 = np.array([0, 0, 1.0])[:, None] @@ -266,7 +262,6 @@ def test_lp_barycenter(): def test_free_support_barycenter(): - measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] measures_weights = [np.array([1.]), np.array([1.])] @@ -282,7 +277,6 @@ def test_free_support_barycenter(): @pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): - a1 = np.array([1.0, 0, 0])[:, None] a2 = np.array([0, 0, 1.0])[:, None] -- cgit v1.2.3