diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-24 11:54:59 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-24 11:54:59 +0200 |
commit | ff104a6dde2d652283f72d7901bbe79dfb8571ed (patch) | |
tree | 786f1cbfc18b5c2904c5faa5b1aed195ace3a591 | |
parent | 01f8c44d3e6dbe129b4b621af370bb6f015ab2b2 (diff) |
add test for emd and emd2
-rw-r--r-- | test/test_ot.py | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index 51ee510..6976818 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -3,7 +3,7 @@ import ot import numpy as np -#import pytest +# import pytest def test_doctest(): @@ -17,8 +17,28 @@ def test_doctest(): doctest.testmod(ot.bregman, verbose=True) +def test_emd_emd2(): + # test emd + n = 100 + + x = np.random.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + G = ot.emd(u, u, M) + + # check G is identity + assert np.allclose(G, np.eye(n) / n) + + w = ot.emd2(u, u, M) + + # check loss=0 + assert np.allclose(w, 0) + + #@pytest.mark.skip(reason="Seems to be a conflict between pytest and multiprocessing") -def test_emd_multi(): +def test_emd2_multi(): from ot.datasets import get_1D_gauss as gauss |