diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2020-01-22 15:17:16 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-01-22 15:17:16 +0100 |
commit | bfde1aca37e0a1ee44b885eb8be624a0f7257bea (patch) | |
tree | 8b06291175fea8267d221ea9c66723e620d22509 /test | |
parent | b6fa567fcb8eaef0699cc8d8ca087ad9c1fb05de (diff) | |
parent | 1b58440457b25aace9dac56aa21144286e60f16e (diff) |
Merge branch 'master' into master
Diffstat (limited to 'test')
-rw-r--r-- | test/test_ot.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index 3dd544c..18b6294 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -14,6 +14,22 @@ from ot.datasets import make_1D_gauss as gauss import pytest +def test_emd_dimension_mismatch(): + # test emd and emd2 for dimension mismatch + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples + 1) + + M = ot.dist(x, x) + + np.testing.assert_raises(AssertionError, ot.emd, a, a, M) + + np.testing.assert_raises(AssertionError, ot.emd2, a, a, M) + + def test_emd_emd2(): # test emd and emd2 for simple identity n = 100 |