From 55a38f8253e5831105d2c329f4d8ed77686d1330 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Thu, 13 Jul 2017 15:30:39 +0900 Subject: Added optional maximal number of iteration --- ot/lp/EMD.h | 2 +- ot/lp/EMD_wrapper.cpp | 3 +-- ot/lp/__init__.py | 10 +++++----- ot/lp/emd_wrap.pyx | 10 +++++----- ot/lp/network_simplex_simple.h | 2 +- 5 files changed, 13 insertions(+), 14 deletions(-) (limited to 'ot') diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index 40d7192..59a5af8 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -24,6 +24,6 @@ using namespace lemon; typedef unsigned int node_id_type; -void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost); +void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter); #endif diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index cad4750..2d448a0 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -15,11 +15,10 @@ #include "EMD.h" -void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost) { +void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) { // beware M and C anre strored in row major C style!!! int n, m, i,cur; double max; - int max_iter=10000; typedef FullBipartiteDigraph Digraph; DIGRAPH_TYPEDEFS(FullBipartiteDigraph); diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index db3da78..673242d 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -11,7 +11,7 @@ import multiprocessing -def emd(a, b, M): +def emd(a, b, M, max_iter=-1): """Solves the Earth Movers distance problem and returns the OT matrix @@ -80,9 +80,9 @@ def emd(a, b, M): if len(b) == 0: b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - return emd_c(a, b, M) + return emd_c(a, b, M, max_iter) -def emd2(a, b, M,processes=multiprocessing.cpu_count()): +def emd2(a, b, M,processes=multiprocessing.cpu_count(), max_iter=-1): """Solves the Earth Movers distance problem and returns the loss .. math:: @@ -151,12 +151,12 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count()): b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] if len(b.shape)==1: - return emd2_c(a, b, M) + return emd2_c(a, b, M, max_iter) else: nb=b.shape[1] #res=[emd2_c(a,b[:,i].copy(),M) for i in range(nb)] def f(b): - return emd2_c(a,b,M) + return emd2_c(a,b,M, max_iter) res= parmap(f, [b[:,i] for i in range(nb)],processes) return np.array(res) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 46794ab..e8fdba4 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -12,13 +12,13 @@ cimport cython cdef extern from "EMD.h": - void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost) + void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) @cython.boundscheck(False) @cython.wraparound(False) -def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): +def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int maxiter): """ Solves the Earth Movers distance problem and returns the optimal transport matrix @@ -66,13 +66,13 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod b=np.ones((n2,))/n2 # calling the function - EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, &cost) + EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, &cost, maxiter) return G @cython.boundscheck(False) @cython.wraparound(False) -def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): +def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int maxiter): """ Solves the Earth Movers distance problem and returns the optimal transport loss @@ -120,7 +120,7 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo b=np.ones((n2,))/n2 # calling the function - EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, &cost) + EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, &cost, maxiter) cost=0 for i in range(n1): diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 125c818..08449f6 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -1426,7 +1426,7 @@ namespace lemon { //pivot.setDantzig(true); // Execute the Network Simplex algorithm while (pivot.findEnteringArc()) { - if(++iter_number>=max_iter&&max_iter>0){ + if(max_iter > 0 && ++iter_number>=max_iter&&max_iter>0){ char errMess[1000]; sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number ); std::cerr << errMess; -- cgit v1.2.3