From e58cd780ccf87736265e4e1a39afa3a167325ccc Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Sat, 9 Sep 2017 12:37:56 +0900 Subject: Added convergence status to the log --- ot/lp/__init__.py | 16 ++++++++++++---- ot/lp/emd_wrap.pyx | 28 +++++++++++++++++----------- 2 files changed, 29 insertions(+), 15 deletions(-) (limited to 'ot/lp') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 8edd8ec..0f40c19 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -12,7 +12,7 @@ import multiprocessing import numpy as np # import compiled emd -from .emd_wrap import emd_c +from .emd_wrap import emd_c, checkResult from ..utils import parmap @@ -94,12 +94,15 @@ def emd(a, b, M, numItermax=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - G, cost, u, v = emd_c(a, b, M, numItermax) + G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) + resultCodeString = checkResult(resultCode) if log: log = {} log['cost'] = cost log['u'] = u log['v'] = v + log['warning'] = resultCodeString + log['resultCode'] = resultCode return G, log return G @@ -177,15 +180,20 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log= if log: def f(b): - G, cost, u, v = emd_c(a, b, M, numItermax) + G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) + resultCodeString = checkResult(resultCode) log = {} log['G'] = G log['u'] = u log['v'] = v + log['warning'] = resultCodeString + log['resultCode'] = resultCode return [cost, log] else: def f(b): - return emd_c(a, b, M, numItermax)[1] + G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) + checkResult(resultCode) + return cost if len(b.shape) == 1: return f(b) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 5618dfc..19bcdd8 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -7,12 +7,12 @@ Cython linker with C solver # # License: MIT License -import warnings import numpy as np cimport numpy as np cimport cython +import warnings cdef extern from "EMD.h": @@ -20,6 +20,19 @@ cdef extern from "EMD.h": cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED +def checkResult(resultCode): + if resultCode == OPTIMAL: + return None + + if resultCode == INFEASIBLE: + message = "Problem infeasible. Check that a and b are in the simplex" + elif resultCode == UNBOUNDED: + message = "Problem unbounded" + elif resultCode == MAX_ITER_REACHED: + message = "numItermax reached before optimality. Try to increase numItermax." + warnings.warn(message) + return message + @cython.boundscheck(False) @cython.wraparound(False) @@ -77,13 +90,6 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod b=np.ones((n2,))/n2 # calling the function - cdef int resultSolver = EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, numItermax) - if resultSolver != OPTIMAL: - if resultSolver == INFEASIBLE: - warnings.warn("Problem infeasible. Check that a and b are in the simplex") - elif resultSolver == UNBOUNDED: - warnings.warn("Problem unbounded") - elif resultSolver == MAX_ITER_REACHED: - warnings.warn("numItermax reached before optimality. Try to increase numItermax.") - - return G, cost, alpha, beta + cdef int resultCode = EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, numItermax) + + return G, cost, alpha, beta, resultCode -- cgit v1.2.3