summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index ba3ef6a..9a4e175 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -29,9 +29,12 @@ def test_emd_dimension_and_mass_mismatch():
np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
+ # test emd and emd2 for mass mismatch
+ a = ot.utils.unif(n_samples)
b = a.copy()
a[0] = 100
np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
+ np.testing.assert_raises(AssertionError, ot.emd2, a, b, M)
def test_emd_backends(nx):