summaryrefslogtreecommitdiff
path: root/ot/lp
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp')
-rw-r--r--ot/lp/__init__.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index bb9829a..eabdd3a 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -114,6 +114,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
+ "Dimension mismatch, check dimensions of M with a and b"
+
if dense:
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
else:
@@ -226,6 +229,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
+ "Dimension mismatch, check dimensions of M with a and b"
+
if log or return_matrix:
def f(b):
if dense: