From e58cd780ccf87736265e4e1a39afa3a167325ccc Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Sat, 9 Sep 2017 12:37:56 +0900 Subject: Added convergence status to the log --- ot/lp/__init__.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) (limited to 'ot/lp/__init__.py') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 8edd8ec..0f40c19 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -12,7 +12,7 @@ import multiprocessing import numpy as np # import compiled emd -from .emd_wrap import emd_c +from .emd_wrap import emd_c, checkResult from ..utils import parmap @@ -94,12 +94,15 @@ def emd(a, b, M, numItermax=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - G, cost, u, v = emd_c(a, b, M, numItermax) + G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) + resultCodeString = checkResult(resultCode) if log: log = {} log['cost'] = cost log['u'] = u log['v'] = v + log['warning'] = resultCodeString + log['resultCode'] = resultCode return G, log return G @@ -177,15 +180,20 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log= if log: def f(b): - G, cost, u, v = emd_c(a, b, M, numItermax) + G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) + resultCodeString = checkResult(resultCode) log = {} log['G'] = G log['u'] = u log['v'] = v + log['warning'] = resultCodeString + log['resultCode'] = resultCode return [cost, log] else: def f(b): - return emd_c(a, b, M, numItermax)[1] + G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) + checkResult(resultCode) + return cost if len(b.shape) == 1: return f(b) -- cgit v1.2.3