diff options
Diffstat (limited to 'ot/lp')
-rw-r--r-- | ot/lp/__init__.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index f77c3d7..4cce41c 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -109,7 +109,7 @@ def emd(a, b, M, numItermax=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - assert (a.shape[0] == M.shape[0] or b.shape[0] == M.shape[1]), \ + assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" G, cost, u, v, result_code = emd_c(a, b, M, numItermax) @@ -215,7 +215,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - assert (a.shape[0] == M.shape[0] or b.shape[0] == M.shape[1]), \ + assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" if log or return_matrix: |