summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorKilian <kilian.fatras@gmail.com>2019-12-10 11:23:50 +0100
committerKilian <kilian.fatras@gmail.com>2019-12-10 11:23:50 +0100
commita9bbc2cfdffd22ceee3256102e470df6c25338f3 (patch)
tree7b379bfce4cc552ade36130e3a5a7836e5d5d9e8 /ot/lp/__init__.py
parent92dbe259032d340a259209e477e9aac74897689e (diff)
change or in assert by and
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index f77c3d7..4cce41c 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -109,7 +109,7 @@ def emd(a, b, M, numItermax=100000, log=False):
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
- assert (a.shape[0] == M.shape[0] or b.shape[0] == M.shape[1]), \
+ assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
@@ -215,7 +215,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
- assert (a.shape[0] == M.shape[0] or b.shape[0] == M.shape[1]), \
+ assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
if log or return_matrix: