summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 8edd8ec..0f40c19 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -12,7 +12,7 @@ import multiprocessing
import numpy as np
# import compiled emd
-from .emd_wrap import emd_c
+from .emd_wrap import emd_c, checkResult
from ..utils import parmap
@@ -94,12 +94,15 @@ def emd(a, b, M, numItermax=100000, log=False):
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
- G, cost, u, v = emd_c(a, b, M, numItermax)
+ G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
+ resultCodeString = checkResult(resultCode)
if log:
log = {}
log['cost'] = cost
log['u'] = u
log['v'] = v
+ log['warning'] = resultCodeString
+ log['resultCode'] = resultCode
return G, log
return G
@@ -177,15 +180,20 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=
if log:
def f(b):
- G, cost, u, v = emd_c(a, b, M, numItermax)
+ G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
+ resultCodeString = checkResult(resultCode)
log = {}
log['G'] = G
log['u'] = u
log['v'] = v
+ log['warning'] = resultCodeString
+ log['resultCode'] = resultCode
return [cost, log]
else:
def f(b):
- return emd_c(a, b, M, numItermax)[1]
+ G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
+ checkResult(resultCode)
+ return cost
if len(b.shape) == 1:
return f(b)