summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorAntoine Rolet <antoine.rolet@gmail.com>2017-09-09 12:37:56 +0900
committerAntoine Rolet <antoine.rolet@gmail.com>2017-09-09 12:37:56 +0900
commite58cd780ccf87736265e4e1a39afa3a167325ccc (patch)
treeeacf267e1ee18372054728961163c5bce19f3a06 /ot/lp/__init__.py
parenta37e52e64f300fa0165a58932d5ac0ef1dd8c6f7 (diff)
Added convergence status to the log
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 8edd8ec..0f40c19 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -12,7 +12,7 @@ import multiprocessing
import numpy as np
# import compiled emd
-from .emd_wrap import emd_c
+from .emd_wrap import emd_c, checkResult
from ..utils import parmap
@@ -94,12 +94,15 @@ def emd(a, b, M, numItermax=100000, log=False):
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
- G, cost, u, v = emd_c(a, b, M, numItermax)
+ G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
+ resultCodeString = checkResult(resultCode)
if log:
log = {}
log['cost'] = cost
log['u'] = u
log['v'] = v
+ log['warning'] = resultCodeString
+ log['resultCode'] = resultCode
return G, log
return G
@@ -177,15 +180,20 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=
if log:
def f(b):
- G, cost, u, v = emd_c(a, b, M, numItermax)
+ G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
+ resultCodeString = checkResult(resultCode)
log = {}
log['G'] = G
log['u'] = u
log['v'] = v
+ log['warning'] = resultCodeString
+ log['resultCode'] = resultCode
return [cost, log]
else:
def f(b):
- return emd_c(a, b, M, numItermax)[1]
+ G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
+ checkResult(resultCode)
+ return cost
if len(b.shape) == 1:
return f(b)