diff options
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- | test/test_ot.py | 112 |
1 files changed, 104 insertions, 8 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index 7652394..47df946 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -7,20 +7,27 @@ import warnings import numpy as np +from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss import pytest -def test_doctest(): - import doctest +def test_emd_dimension_mismatch(): + # test emd and emd2 for dimension mismatch + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples + 1) - # test lp solver - doctest.testmod(ot.lp, verbose=True) + M = ot.dist(x, x) - # test bregman solver - doctest.testmod(ot.bregman, verbose=True) + np.testing.assert_raises(AssertionError, ot.emd, a, a, M) + + np.testing.assert_raises(AssertionError, ot.emd2, a, a, M) def test_emd_emd2(): @@ -37,7 +44,7 @@ def test_emd_emd2(): # check G is identity np.testing.assert_allclose(G, np.eye(n) / n) - # check constratints + # check constraints np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn @@ -46,6 +53,64 @@ def test_emd_emd2(): np.testing.assert_allclose(w, 0) +def test_emd_1d_emd2_1d(): + # test emd1d gives similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) + + M = ot.dist(u, v, metric='sqeuclidean') + + G, log = ot.emd([], [], M, log=True) + wass = log["cost"] + G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True) + wass1d = log["cost"] + wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False) + wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False) + + # check loss is similar + np.testing.assert_allclose(wass, wass1d) + np.testing.assert_allclose(wass, wass1d_emd2) + + # check loss is similar to scipy's implementation for Euclidean metric + wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, ))) + np.testing.assert_allclose(wass_sp, wass1d_euc) + + # check constraints + np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1)) + np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0)) + + # check G is similar + np.testing.assert_allclose(G, G_1d) + + # check AssertionError is raised if called on non 1d arrays + u = np.random.randn(n, 2) + v = np.random.randn(m, 2) + with pytest.raises(AssertionError): + ot.emd_1d(u, v, [], []) + + +def test_wass_1d(): + # test emd1d gives similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) + + M = ot.dist(u, v, metric='sqeuclidean') + + G, log = ot.emd([], [], M, log=True) + wass = log["cost"] + + wass1d = ot.wasserstein_1d(u, v, [], [], p=2.) + + # check loss is similar + np.testing.assert_allclose(np.sqrt(wass), wass1d) + + def test_emd_empty(): # test emd and emd2 for simple identity n = 100 @@ -60,7 +125,7 @@ def test_emd_empty(): # check G is identity np.testing.assert_allclose(G, np.eye(n) / n) - # check constratints + # check constraints np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn @@ -69,6 +134,28 @@ def test_emd_empty(): np.testing.assert_allclose(w, 0) +def test_emd_sparse(): + + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + x2 = rng.randn(n, 2) + + M = ot.dist(x, x2) + + G = ot.emd([], [], M, dense=True) + + Gs = ot.emd([], [], M, dense=False) + + ws = ot.emd2([], [], M, dense=False) + + # check G is the same + np.testing.assert_allclose(G, Gs.todense()) + # check value + np.testing.assert_allclose(Gs.multiply(M).sum(), ws, rtol=1e-6) + + def test_emd2_multi(): n = 500 # nb bins @@ -100,7 +187,12 @@ def test_emd2_multi(): emdn = ot.emd2(a, b, M) ot.toc('multi proc : {} s') + ot.tic() + emdn2 = ot.emd2(a, b, M, dense=False) + ot.toc('multi proc : {} s') + np.testing.assert_allclose(emd1, emdn) + np.testing.assert_allclose(emd1, emdn2, rtol=1e-6) # emd loss multipro proc with log ot.tic() @@ -246,6 +338,10 @@ def test_dual_variables(): np.testing.assert_almost_equal(cost1, log['cost']) check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost']) + constraint_violation = log['u'][:, None] + log['v'][None, :] - M + + assert constraint_violation.max() < 1e-8 + def check_duality_gap(a, b, M, G, u, v, cost): cost_dual = np.vdot(a, u) + np.vdot(b, v) |