summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2020-01-22 15:17:16 +0100
committerGitHub <noreply@github.com>2020-01-22 15:17:16 +0100
commitbfde1aca37e0a1ee44b885eb8be624a0f7257bea (patch)
tree8b06291175fea8267d221ea9c66723e620d22509 /test
parentb6fa567fcb8eaef0699cc8d8ca087ad9c1fb05de (diff)
parent1b58440457b25aace9dac56aa21144286e60f16e (diff)
Merge branch 'master' into master
Diffstat (limited to 'test')
-rw-r--r--test/test_ot.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 3dd544c..18b6294 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -14,6 +14,22 @@ from ot.datasets import make_1D_gauss as gauss
import pytest
+def test_emd_dimension_mismatch():
+ # test emd and emd2 for dimension mismatch
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples + 1)
+
+ M = ot.dist(x, x)
+
+ np.testing.assert_raises(AssertionError, ot.emd, a, a, M)
+
+ np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
+
+
def test_emd_emd2():
# test emd and emd2 for simple identity
n = 100