From 92233f79e098f1930248d815e66c0a929508af59 Mon Sep 17 00:00:00 2001 From: Kilian Date: Mon, 9 Dec 2019 15:56:48 +0100 Subject: add assert for emd dimension mismatch --- ot/lp/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'ot/lp/__init__.py') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 0c92810..f77c3d7 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -109,6 +109,9 @@ def emd(a, b, M, numItermax=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + assert (a.shape[0] == M.shape[0] or b.shape[0] == M.shape[1]), \ + "Dimension mismatch, check dimensions of M with a and b" + G, cost, u, v, result_code = emd_c(a, b, M, numItermax) result_code_string = check_result(result_code) if log: @@ -212,6 +215,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + assert (a.shape[0] == M.shape[0] or b.shape[0] == M.shape[1]), \ + "Dimension mismatch, check dimensions of M with a and b" + if log or return_matrix: def f(b): G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) -- cgit v1.2.3