summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/lp/__init__.py6
-rw-r--r--test/test_ot.py16
2 files changed, 22 insertions, 0 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 0c92810..f77c3d7 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -109,6 +109,9 @@ def emd(a, b, M, numItermax=100000, log=False):
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ assert (a.shape[0] == M.shape[0] or b.shape[0] == M.shape[1]), \
+ "Dimension mismatch, check dimensions of M with a and b"
+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
result_code_string = check_result(result_code)
if log:
@@ -212,6 +215,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ assert (a.shape[0] == M.shape[0] or b.shape[0] == M.shape[1]), \
+ "Dimension mismatch, check dimensions of M with a and b"
+
if log or return_matrix:
def f(b):
G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
diff --git a/test/test_ot.py b/test/test_ot.py
index dacae0a..1343604 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -14,6 +14,22 @@ from ot.datasets import make_1D_gauss as gauss
import pytest
+def test_emd_dimension_mismatch():
+ # test emd and emd2 for simple identity
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples + 1)
+
+ M = ot.dist(x, x)
+
+ np.testing.assert_raises(AssertionError, emd, a, a, M)
+
+ np.testing.assert_raises(AssertionError, emd2, a, a, M)
+
+
def test_emd_emd2():
# test emd and emd2 for simple identity
n = 100