diff options
-rw-r--r-- | ot/lp/__init__.py | 6 | ||||
-rw-r--r-- | test/test_ot.py | 16 |
2 files changed, 22 insertions, 0 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index bb9829a..eabdd3a 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -114,6 +114,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ + "Dimension mismatch, check dimensions of M with a and b" + if dense: G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) else: @@ -226,6 +229,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ + "Dimension mismatch, check dimensions of M with a and b" + if log or return_matrix: def f(b): if dense: diff --git a/test/test_ot.py b/test/test_ot.py index 3dd544c..18b6294 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -14,6 +14,22 @@ from ot.datasets import make_1D_gauss as gauss import pytest +def test_emd_dimension_mismatch(): + # test emd and emd2 for dimension mismatch + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples + 1) + + M = ot.dist(x, x) + + np.testing.assert_raises(AssertionError, ot.emd, a, a, M) + + np.testing.assert_raises(AssertionError, ot.emd2, a, a, M) + + def test_emd_emd2(): # test emd and emd2 for simple identity n = 100 |