summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
Diffstat (limited to 'ot')
-rw-r--r--ot/lp/__init__.py16
-rw-r--r--ot/lp/emd_wrap.pyx28
2 files changed, 29 insertions, 15 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 8edd8ec..0f40c19 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -12,7 +12,7 @@ import multiprocessing
import numpy as np
# import compiled emd
-from .emd_wrap import emd_c
+from .emd_wrap import emd_c, checkResult
from ..utils import parmap
@@ -94,12 +94,15 @@ def emd(a, b, M, numItermax=100000, log=False):
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
- G, cost, u, v = emd_c(a, b, M, numItermax)
+ G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
+ resultCodeString = checkResult(resultCode)
if log:
log = {}
log['cost'] = cost
log['u'] = u
log['v'] = v
+ log['warning'] = resultCodeString
+ log['resultCode'] = resultCode
return G, log
return G
@@ -177,15 +180,20 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=
if log:
def f(b):
- G, cost, u, v = emd_c(a, b, M, numItermax)
+ G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
+ resultCodeString = checkResult(resultCode)
log = {}
log['G'] = G
log['u'] = u
log['v'] = v
+ log['warning'] = resultCodeString
+ log['resultCode'] = resultCode
return [cost, log]
else:
def f(b):
- return emd_c(a, b, M, numItermax)[1]
+ G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
+ checkResult(resultCode)
+ return cost
if len(b.shape) == 1:
return f(b)
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 5618dfc..19bcdd8 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -7,12 +7,12 @@ Cython linker with C solver
#
# License: MIT License
-import warnings
import numpy as np
cimport numpy as np
cimport cython
+import warnings
cdef extern from "EMD.h":
@@ -20,6 +20,19 @@ cdef extern from "EMD.h":
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
+def checkResult(resultCode):
+ if resultCode == OPTIMAL:
+ return None
+
+ if resultCode == INFEASIBLE:
+ message = "Problem infeasible. Check that a and b are in the simplex"
+ elif resultCode == UNBOUNDED:
+ message = "Problem unbounded"
+ elif resultCode == MAX_ITER_REACHED:
+ message = "numItermax reached before optimality. Try to increase numItermax."
+ warnings.warn(message)
+ return message
+
@cython.boundscheck(False)
@cython.wraparound(False)
@@ -77,13 +90,6 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
b=np.ones((n2,))/n2
# calling the function
- cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, numItermax)
- if resultSolver != OPTIMAL:
- if resultSolver == INFEASIBLE:
- warnings.warn("Problem infeasible. Check that a and b are in the simplex")
- elif resultSolver == UNBOUNDED:
- warnings.warn("Problem unbounded")
- elif resultSolver == MAX_ITER_REACHED:
- warnings.warn("numItermax reached before optimality. Try to increase numItermax.")
-
- return G, cost, alpha, beta
+ cdef int resultCode = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, numItermax)
+
+ return G, cost, alpha, beta, resultCode