diff options
Diffstat (limited to 'ot')
-rw-r--r-- | ot/lp/EMD.h | 2 | ||||
-rw-r--r-- | ot/lp/EMD_wrapper.cpp | 50 | ||||
-rw-r--r-- | ot/lp/__init__.py | 29 | ||||
-rw-r--r-- | ot/lp/emd_wrap.pyx | 26 | ||||
-rw-r--r-- | ot/lp/network_simplex_simple.h | 55 |
5 files changed, 81 insertions, 81 deletions
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index aa92441..15e9115 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -29,6 +29,6 @@ enum ProblemType { UNBOUNDED }; -int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter); +int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int max_iter); #endif diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index c8c2eb3..8ac43c7 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -15,9 +15,10 @@ #include "EMD.h" -int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) { +int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, + double* alpha, double* beta, double *cost, int max_iter) { // beware M and C anre strored in row major C style!!! - int n, m, i,cur; + int n, m, i, cur; double max; typedef FullBipartiteDigraph Digraph; @@ -25,21 +26,20 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c // Get the number of non zero coordinates for r and c n=0; - for (node_id_type i=0; i<n1; i++) { + for (int i=0; i<n1; i++) { double val=*(X+i); if (val>0) { n++; } } m=0; - for (node_id_type i=0; i<n2; i++) { + for (int i=0; i<n2; i++) { double val=*(Y+i); if (val>0) { m++; } } - // Define the graph std::vector<int> indI(n), indJ(m); @@ -49,28 +49,23 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c // Set supply and demand, don't account for 0 values (faster) - max=0; cur=0; - for (node_id_type i=0; i<n1; i++) { + for (int i=0; i<n1; i++) { double val=*(X+i); if (val>0) { - weights1[ di.nodeFromId(cur) ] = val; - max+=val; + weights1[ cur ] = val; indI[cur++]=i; } } // Demand is actually negative supply... - max=0; cur=0; - for (node_id_type i=0; i<n2; i++) { + for (int i=0; i<n2; i++) { double val=*(Y+i); if (val>0) { - weights2[ di.nodeFromId(cur) ] = -val; + weights2[ cur ] = -val; indJ[cur++]=i; - - max-=val; } } @@ -78,14 +73,10 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c net.supplyMap(&weights1[0], n, &weights2[0], m); // Set the cost of each edge - max=0; - for (node_id_type i=0; i<n; i++) { - for (node_id_type j=0; j<m; j++) { + for (int i=0; i<n; i++) { + for (int j=0; j<m; j++) { double val=*(D+indI[i]*n2+indJ[j]); net.setCost(di.arcFromId(i*m+j), val); - if (val>max) { - max=val; - } } } @@ -103,14 +94,17 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c } } else { - for (node_id_type i=0; i<n; i++) - { - for (node_id_type j=0; j<m; j++) - { - *(G+indI[i]*n2+indJ[j]) = net.flow(di.arcFromId(i*m+j)); - } - }; - *cost = net.totalCost(); + *cost = 0; + Arc a; di.first(a); + for (; a != INVALID; di.next(a)) { + int i = di.source(a); + int j = di.target(a); + double flow = net.flow(a); + *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); + *(G+indI[i]*n2+indJ[j-n]) = flow; + *(alpha + indI[i]) = -net.potential(i); + *(beta + indJ[j-n]) = net.potential(j); + } }; diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index de91e74..a14d4e4 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -14,7 +14,7 @@ from ..utils import parmap import multiprocessing -def emd(a, b, M, numItermax=100000): +def emd(a, b, M, numItermax=100000, dual_variables=False): """Solves the Earth Movers distance problem and returns the OT matrix @@ -86,8 +86,10 @@ def emd(a, b, M, numItermax=100000): if len(b) == 0: b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - return emd_c(a, b, M, numItermax) - + G, alpha, beta = emd_c(a, b, M, numItermax) + if dual_variables: + return G, alpha, beta + return G def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): """Solves the Earth Movers distance problem and returns the loss @@ -159,14 +161,13 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0] if len(b) == 0: b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - - if len(b.shape) == 1: - return emd2_c(a, b, M, numItermax) - else: - nb = b.shape[1] - # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)] - - def f(b): - return emd2_c(a, b, M, numItermax) - res = parmap(f, [b[:, i] for i in range(nb)], processes) - return np.array(res) + + if len(b.shape)==1: + return emd2_c(a, b, M, numItermax)[0] + nb = b.shape[1] + # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)] + + def f(b): + return emd2_c(a,b,M, max_iter)[0] + res= parmap(f, [b[:,i] for i in range(nb)],processes) + return np.array(res) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 26d3330..7056e0e 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -15,7 +15,7 @@ cimport cython cdef extern from "EMD.h": - int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) + int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int max_iter) cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED @@ -63,8 +63,11 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod cdef int n1= M.shape[0] cdef int n2= M.shape[1] - cdef float cost=0 + cdef double cost=0 cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) + cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1) + cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2) + if not len(a): a=np.ones((n1,))/n1 @@ -73,14 +76,14 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod b=np.ones((n2,))/n2 # calling the function - cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, max_iter) + cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) if resultSolver != OPTIMAL: if resultSolver == INFEASIBLE: print("Problem infeasible. Try to increase numItermax.") elif resultSolver == UNBOUNDED: print("Problem unbounded") - return G + return G, alpha, beta @cython.boundscheck(False) @cython.wraparound(False) @@ -125,27 +128,24 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo cdef int n1= M.shape[0] cdef int n2= M.shape[1] - cdef float cost=0 + cdef double cost=0 cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) + cdef np.ndarray[double, ndim = 1, mode = "c"] alpha = np.zeros([n1]) + cdef np.ndarray[double, ndim = 1, mode = "c"] beta = np.zeros([n2]) + if not len(a): a=np.ones((n1,))/n1 if not len(b): b=np.ones((n2,))/n2 - # calling the function - cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, max_iter) + cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) if resultSolver != OPTIMAL: if resultSolver == INFEASIBLE: print("Problem infeasible. Try to inscrease numItermax.") elif resultSolver == UNBOUNDED: print("Problem unbounded") - cost=0 - for i in range(n1): - for j in range(n2): - cost+=G[i,j]*M[i,j] - - return cost + return cost, alpha, beta diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 64856a0..08449f6 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -28,6 +28,12 @@ #ifndef LEMON_NETWORK_SIMPLEX_SIMPLE_H #define LEMON_NETWORK_SIMPLEX_SIMPLE_H #define DEBUG_LVL 0 + +#if DEBUG_LVL>0 +#include <iomanip> +#endif + + #define EPSILON 10*2.2204460492503131e-016 #define MAX_DEBUG_ITER 100000 @@ -220,7 +226,7 @@ namespace lemon { /// mixed order in the internal data structure. /// In special cases, it could lead to better overall performance, /// but it is usually slower. Therefore it is disabled by default. - NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,double maxiters) : + NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,int maxiters) : _graph(graph), //_arc_id(graph), _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), MAX(std::numeric_limits<Value>::max()), @@ -278,7 +284,7 @@ namespace lemon { private: - double max_iter; + int max_iter; TEMPLATE_DIGRAPH_TYPEDEFS(GR); typedef std::vector<int> IntVector; @@ -676,14 +682,12 @@ namespace lemon { /// \see resetParams(), reset() ProblemType run() { #if DEBUG_LVL>0 - mexPrintf("OPTIMAL = %d\nINFEASIBLE = %d\nUNBOUNDED = %d\n",OPTIMAL,INFEASIBLE,UNBOUNDED); - mexEvalString("drawnow;"); + std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "nUNBOUNDED = " << UNBOUNDED << "\n"; #endif if (!init()) return INFEASIBLE; #if DEBUG_LVL>0 - mexPrintf("Init done, starting iterations\n"); - mexEvalString("drawnow;"); + std::cout << "Init done, starting iterations\n"; #endif return start(); } @@ -1422,10 +1426,10 @@ namespace lemon { //pivot.setDantzig(true); // Execute the Network Simplex algorithm while (pivot.findEnteringArc()) { - if(++iter_number>=max_iter&&max_iter>0){ + if(max_iter > 0 && ++iter_number>=max_iter&&max_iter>0){ char errMess[1000]; - // sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher",iter_number ); - // mexWarnMsgTxt(errMess); + sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number ); + std::cerr << errMess; break; } #if DEBUG_LVL>0 @@ -1440,12 +1444,13 @@ namespace lemon { for (int i=0; i<_flow.size(); i++) { sumFlow+=_state[i]*_flow[i]; } - mexPrintf("Sum of the flow %.100f\n%d iterations, current cost=%.20f\nReduced cost=%.30f\nPrecision =%.30f\n",sumFlow,niter, curCost,_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]), -EPSILON*(a)); - mexPrintf("Arc in = (%d,%d)\n",_node_id(_source[in_arc]),_node_id(_target[in_arc])); - mexPrintf("Supplies = (%f,%f)\n",_supply[_source[in_arc]],_supply[_target[in_arc]]); - - mexPrintf("%.30f\n%.30f\n%.30f\n%.30f\n%",_cost[in_arc],_pi[_source[in_arc]],_pi[_target[in_arc]],a); - mexEvalString("drawnow;"); + std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; + std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; + std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; + std::cout << _cost[in_arc] << "\n"; + std::cout << _pi[_source[in_arc]] << "\n"; + std::cout << _pi[_target[in_arc]] << "\n"; + std::cout << a << "\n"; } #endif @@ -1459,11 +1464,11 @@ namespace lemon { } #if DEBUG_LVL>0 else{ - mexPrintf("No change\n"); + std::cout << "No change\n"; } #endif #if DEBUG_LVL>1 - mexPrintf("Arc in = (%d,%d)\n",_source[in_arc],_target[in_arc]); + std::cout << "Arc in = (" << _source[in_arc] << ", " << _target[in_arc] << ")\n"; #endif } @@ -1478,23 +1483,23 @@ namespace lemon { for (int i=0; i<_flow.size(); i++) { sumFlow+=_state[i]*_flow[i]; } - mexPrintf("Sum of the flow %.100f\n%d iterations, current cost=%.20f\nReduced cost=%.30f\nPrecision =%.30f",sumFlow,niter, curCost,_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]), -EPSILON*(a)); - mexPrintf("Arc in = (%d,%d)\n",_node_id(_source[in_arc]),_node_id(_target[in_arc])); - mexPrintf("Supplies = (%f,%f)\n",_supply[_source[in_arc]],_supply[_target[in_arc]]); + + std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; + + std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; + std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; - mexEvalString("drawnow;"); #endif #if DEBUG_LVL>1 - double sumFlow=0; + sumFlow=0; for (int i=0; i<_flow.size(); i++) { sumFlow+=_state[i]*_flow[i]; if (_state[i]==STATE_TREE) { - mexPrintf("Non zero value at (%d,%d)\n",_node_num+1-_source[i],_node_num+1-_target[i]); + std::cout << "Non zero value at (" << _node_num+1-_source[i] << ", " << _node_num+1-_target[i] << ")\n"; } } - mexPrintf("Sum of the flow %.100f\n%d iterations, current cost=%.20f\n",sumFlow,niter, totalCost()); - mexEvalString("drawnow;"); + std::cout << "Sum of the flow " << sumFlow << "\n"<< niter <<" iterations, current cost=" << totalCost() << "\n"; #endif // Check feasibility for (int e = _search_arc_num; e != _all_arc_num; ++e) { |