diff options
author | Antoine Rolet <antoine.rolet@gmail.com> | 2017-09-09 12:37:56 +0900 |
---|---|---|
committer | Antoine Rolet <antoine.rolet@gmail.com> | 2017-09-09 12:37:56 +0900 |
commit | e58cd780ccf87736265e4e1a39afa3a167325ccc (patch) | |
tree | eacf267e1ee18372054728961163c5bce19f3a06 /ot/lp/__init__.py | |
parent | a37e52e64f300fa0165a58932d5ac0ef1dd8c6f7 (diff) |
Added convergence status to the log
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- | ot/lp/__init__.py | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 8edd8ec..0f40c19 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -12,7 +12,7 @@ import multiprocessing import numpy as np # import compiled emd -from .emd_wrap import emd_c +from .emd_wrap import emd_c, checkResult from ..utils import parmap @@ -94,12 +94,15 @@ def emd(a, b, M, numItermax=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - G, cost, u, v = emd_c(a, b, M, numItermax) + G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) + resultCodeString = checkResult(resultCode) if log: log = {} log['cost'] = cost log['u'] = u log['v'] = v + log['warning'] = resultCodeString + log['resultCode'] = resultCode return G, log return G @@ -177,15 +180,20 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log= if log: def f(b): - G, cost, u, v = emd_c(a, b, M, numItermax) + G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) + resultCodeString = checkResult(resultCode) log = {} log['G'] = G log['u'] = u log['v'] = v + log['warning'] = resultCodeString + log['resultCode'] = resultCode return [cost, log] else: def f(b): - return emd_c(a, b, M, numItermax)[1] + G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) + checkResult(resultCode) + return cost if len(b.shape) == 1: return f(b) |