diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-09-13 08:11:00 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-09-13 08:11:00 +0200 |
commit | a53ede95f916a11e2150ab7917820d813c0034bc (patch) | |
tree | 24304d83267d51b962e18722553973bbc75509f2 | |
parent | 62dcfbfb78a2be24379cd5cdb4aec70d8c4befaa (diff) | |
parent | e52b6eb41228a7f8e381cf73c06e0dffba5773be (diff) |
Merge pull request #29 from arolet/ot_dual_variables
Dual variables in EMD_wrapper
-rw-r--r-- | README.md | 2 | ||||
-rw-r--r-- | ot/lp/EMD.h | 5 | ||||
-rw-r--r-- | ot/lp/EMD_wrapper.cpp | 74 | ||||
-rw-r--r-- | ot/lp/__init__.py | 80 | ||||
-rw-r--r-- | ot/lp/emd_wrap.pyx | 100 | ||||
-rw-r--r-- | ot/lp/network_simplex_simple.h | 98 | ||||
-rw-r--r-- | test/test_ot.py | 114 |
7 files changed, 281 insertions, 192 deletions
@@ -138,12 +138,12 @@ The contributors to this library are: * [Léo Gautheron](https://github.com/aje) (GPU implementation) * [Nathalie Gayraud](https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1) * [Stanislas Chambon](https://slasnista.github.io/) +* [Antoine Rolet](https://arolet.github.io/) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): * [Gabriel Peyré](http://gpeyre.github.io/) (Wasserstein Barycenters in Matlab) * [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) ( C++ code for EMD) -* [Antoine Rolet](https://arolet.github.io/) ( Mex file for EMD ) * [Marco Cuturi](http://marcocuturi.net/) (Sinkhorn Knopp in Matlab/Cuda) diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index aa92441..f42e222 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -26,9 +26,10 @@ typedef unsigned int node_id_type; enum ProblemType { INFEASIBLE, OPTIMAL, - UNBOUNDED + UNBOUNDED, + MAX_ITER_REACHED }; -int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter); +int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter); #endif diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index c8c2eb3..fc7ca63 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -15,62 +15,60 @@ #include "EMD.h" -int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) { +int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, + double* alpha, double* beta, double *cost, int maxIter) { // beware M and C anre strored in row major C style!!! - int n, m, i,cur; - double max; + int n, m, i, cur; typedef FullBipartiteDigraph Digraph; DIGRAPH_TYPEDEFS(FullBipartiteDigraph); // Get the number of non zero coordinates for r and c n=0; - for (node_id_type i=0; i<n1; i++) { + for (int i=0; i<n1; i++) { double val=*(X+i); if (val>0) { n++; - } + }else if(val<0){ + return INFEASIBLE; + } } m=0; - for (node_id_type i=0; i<n2; i++) { + for (int i=0; i<n2; i++) { double val=*(Y+i); if (val>0) { m++; - } + }else if(val<0){ + return INFEASIBLE; + } } - // Define the graph std::vector<int> indI(n), indJ(m); std::vector<double> weights1(n), weights2(m); Digraph di(n, m); - NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, max_iter); + NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter); // Set supply and demand, don't account for 0 values (faster) - max=0; cur=0; - for (node_id_type i=0; i<n1; i++) { + for (int i=0; i<n1; i++) { double val=*(X+i); if (val>0) { - weights1[ di.nodeFromId(cur) ] = val; - max+=val; + weights1[ cur ] = val; indI[cur++]=i; } } // Demand is actually negative supply... - max=0; cur=0; - for (node_id_type i=0; i<n2; i++) { + for (int i=0; i<n2; i++) { double val=*(Y+i); if (val>0) { - weights2[ di.nodeFromId(cur) ] = -val; + weights2[ cur ] = -val; indJ[cur++]=i; - - max-=val; } } @@ -78,14 +76,10 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c net.supplyMap(&weights1[0], n, &weights2[0], m); // Set the cost of each edge - max=0; - for (node_id_type i=0; i<n; i++) { - for (node_id_type j=0; j<m; j++) { + for (int i=0; i<n; i++) { + for (int j=0; j<m; j++) { double val=*(D+indI[i]*n2+indJ[j]); net.setCost(di.arcFromId(i*m+j), val); - if (val>max) { - max=val; - } } } @@ -93,26 +87,20 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c // Solve the problem with the network simplex algorithm int ret=net.run(); - if (ret!=(int)net.OPTIMAL) { - if (ret==(int)net.INFEASIBLE) { - std::cout << "Infeasible problem"; + if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) { + *cost = 0; + Arc a; di.first(a); + for (; a != INVALID; di.next(a)) { + int i = di.source(a); + int j = di.target(a); + double flow = net.flow(a); + *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); + *(G+indI[i]*n2+indJ[j-n]) = flow; + *(alpha + indI[i]) = -net.potential(i); + *(beta + indJ[j-n]) = net.potential(j); } - if (ret==(int)net.UNBOUNDED) - { - std::cout << "Unbounded problem"; - } - } else - { - for (node_id_type i=0; i<n; i++) - { - for (node_id_type j=0; j<m; j++) - { - *(G+indI[i]*n2+indJ[j]) = net.flow(di.arcFromId(i*m+j)); - } - }; - *cost = net.totalCost(); - - }; + + } return ret; diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index de91e74..5c09da2 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -7,14 +7,16 @@ Solvers for the original linear program OT problem # # License: MIT License +import multiprocessing + import numpy as np + # import compiled emd -from .emd_wrap import emd_c, emd2_c +from .emd_wrap import emd_c, check_result from ..utils import parmap -import multiprocessing -def emd(a, b, M, numItermax=100000): +def emd(a, b, M, numItermax=100000, log=False): """Solves the Earth Movers distance problem and returns the OT matrix @@ -42,11 +44,17 @@ def emd(a, b, M, numItermax=100000): numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost and dual + variables. Otherwise returns only the optimal transportation matrix. Returns ------- gamma: (ns x nt) ndarray Optimal transportation matrix for the given parameters + log: dict + If input log is true, a dictionary containing the cost and dual + variables and exit status Examples @@ -82,14 +90,24 @@ def emd(a, b, M, numItermax=100000): # if empty array given then use unifor distributions if len(a) == 0: - a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0] + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - - return emd_c(a, b, M, numItermax) - - -def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + + G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + result_code_string = check_result(result_code) + if log: + log = {} + log['cost'] = cost + log['u'] = u + log['v'] = v + log['warning'] = result_code_string + log['result_code'] = result_code + return G, log + return G + + +def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False, return_matrix=False): """Solves the Earth Movers distance problem and returns the loss .. math:: @@ -116,11 +134,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost and dual + variables. Otherwise returns only the optimal transportation cost. + return_matrix: boolean, optional (default=False) + If True, returns the optimal transportation matrix in the log. Returns ------- gamma: (ns x nt) ndarray Optimal transportation matrix for the given parameters + log: dict + If input log is true, a dictionary containing the cost and dual + variables and exit status Examples @@ -156,17 +182,31 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): # if empty array given then use unifor distributions if len(a) == 0: - a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0] + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - if len(b.shape) == 1: - return emd2_c(a, b, M, numItermax) + if log or return_matrix: + def f(b): + G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) + result_code_string = check_result(resultCode) + log = {} + if return_matrix: + log['G'] = G + log['u'] = u + log['v'] = v + log['warning'] = result_code_string + log['result_code'] = resultCode + return [cost, log] else: - nb = b.shape[1] - # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)] - def f(b): - return emd2_c(a, b, M, numItermax) - res = parmap(f, [b[:, i] for i in range(nb)], processes) - return np.array(res) + G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + check_result(result_code) + return cost + + if len(b.shape) == 1: + return f(b) + nb = b.shape[1] + + res = parmap(f, [b[:, i] for i in range(nb)], processes) + return res diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 26d3330..83ee6aa 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -12,81 +12,33 @@ cimport numpy as np cimport cython +import warnings cdef extern from "EMD.h": - int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) - cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED + int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) + cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED +def check_result(result_code): + if result_code == OPTIMAL: + return None -@cython.boundscheck(False) -@cython.wraparound(False) -def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int max_iter): - """ - Solves the Earth Movers distance problem and returns the optimal transport matrix - - gamm=emd(a,b,M) - - .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F - - s.t. \gamma 1 = a - - \gamma^T 1= b - - \gamma\geq 0 - where : + if result_code == INFEASIBLE: + message = "Problem infeasible. Check that a and b are in the simplex" + elif result_code == UNBOUNDED: + message = "Problem unbounded" + elif result_code == MAX_ITER_REACHED: + message = "numItermax reached before optimality. Try to increase numItermax." + warnings.warn(message) + return message - - M is the metric cost matrix - - a and b are the sample weights - - Parameters - ---------- - a : (ns,) ndarray, float64 - source histogram - b : (nt,) ndarray, float64 - target histogram - M : (ns,nt) ndarray, float64 - loss matrix - max_iter : int - The maximum number of iterations before stopping the optimization - algorithm if it has not converged. - - - Returns - ------- - gamma: (ns x nt) ndarray - Optimal transportation matrix for the given parameters - - """ - cdef int n1= M.shape[0] - cdef int n2= M.shape[1] - - cdef float cost=0 - cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) - - if not len(a): - a=np.ones((n1,))/n1 - - if not len(b): - b=np.ones((n2,))/n2 - - # calling the function - cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, max_iter) - if resultSolver != OPTIMAL: - if resultSolver == INFEASIBLE: - print("Problem infeasible. Try to increase numItermax.") - elif resultSolver == UNBOUNDED: - print("Problem unbounded") - - return G @cython.boundscheck(False) @cython.wraparound(False) -def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int max_iter): +def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter): """ - Solves the Earth Movers distance problem and returns the optimal transport loss + Solves the Earth Movers distance problem and returns the optimal transport matrix gamm=emd(a,b,M) @@ -125,8 +77,11 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo cdef int n1= M.shape[0] cdef int n2= M.shape[1] - cdef float cost=0 + cdef double cost=0 cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) + cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1) + cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2) + if not len(a): a=np.ones((n1,))/n1 @@ -135,17 +90,6 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo b=np.ones((n2,))/n2 # calling the function - cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, max_iter) - if resultSolver != OPTIMAL: - if resultSolver == INFEASIBLE: - print("Problem infeasible. Try to inscrease numItermax.") - elif resultSolver == UNBOUNDED: - print("Problem unbounded") - - cost=0 - for i in range(n1): - for j in range(n2): - cost+=G[i,j]*M[i,j] - - return cost + cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) + return G, cost, alpha, beta, result_code diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 64856a0..7c6a4ce 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -28,7 +28,14 @@ #ifndef LEMON_NETWORK_SIMPLEX_SIMPLE_H #define LEMON_NETWORK_SIMPLEX_SIMPLE_H #define DEBUG_LVL 0 -#define EPSILON 10*2.2204460492503131e-016 + +#if DEBUG_LVL>0 +#include <iomanip> +#endif + + +#define EPSILON 2.2204460492503131e-15 +#define _EPSILON 1e-8 #define MAX_DEBUG_ITER 100000 @@ -43,6 +50,7 @@ #include <vector> #include <limits> #include <algorithm> +#include <cstdio> #ifdef HASHMAP #include <hash_map> #else @@ -220,7 +228,7 @@ namespace lemon { /// mixed order in the internal data structure. /// In special cases, it could lead to better overall performance, /// but it is usually slower. Therefore it is disabled by default. - NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,double maxiters) : + NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,int maxiters) : _graph(graph), //_arc_id(graph), _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), MAX(std::numeric_limits<Value>::max()), @@ -253,7 +261,9 @@ namespace lemon { /// The objective function of the problem is unbounded, i.e. /// there is a directed cycle having negative total cost and /// infinite upper bound. - UNBOUNDED + UNBOUNDED, + /// The maximum number of iteration has been reached + MAX_ITER_REACHED }; /// \brief Constants for selecting the type of the supply constraints. @@ -278,7 +288,7 @@ namespace lemon { private: - double max_iter; + int max_iter; TEMPLATE_DIGRAPH_TYPEDEFS(GR); typedef std::vector<int> IntVector; @@ -676,14 +686,12 @@ namespace lemon { /// \see resetParams(), reset() ProblemType run() { #if DEBUG_LVL>0 - mexPrintf("OPTIMAL = %d\nINFEASIBLE = %d\nUNBOUNDED = %d\n",OPTIMAL,INFEASIBLE,UNBOUNDED); - mexEvalString("drawnow;"); + std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED\n"; #endif if (!init()) return INFEASIBLE; #if DEBUG_LVL>0 - mexPrintf("Init done, starting iterations\n"); - mexEvalString("drawnow;"); + std::cout << "Init done, starting iterations\n"; #endif return start(); } @@ -936,15 +944,15 @@ namespace lemon { // Initialize internal data structures bool init() { if (_node_num == 0) return false; - /* + // Check the sum of supply values _sum_supply = 0; for (int i = 0; i != _node_num; ++i) { _sum_supply += _supply[i]; } - if ( !((_stype == GEQ && _sum_supply <= _epsilon ) || - (_stype == LEQ && _sum_supply >= -_epsilon )) ) return false; - */ + if ( fabs(_sum_supply) > _EPSILON ) return false; + + _sum_supply = 0; // Initialize artifical cost Cost ART_COST; @@ -1411,27 +1419,26 @@ namespace lemon { ProblemType start() { PivotRuleImpl pivot(*this); double prevCost=-1; + ProblemType retVal = OPTIMAL; // Perform heuristic initial pivots if (!initialPivots()) return UNBOUNDED; -#if DEBUG_LVL>0 - int niter=0; -#endif int iter_number=0; //pivot.setDantzig(true); // Execute the Network Simplex algorithm while (pivot.findEnteringArc()) { - if(++iter_number>=max_iter&&max_iter>0){ + if(max_iter > 0 && ++iter_number>=max_iter&&max_iter>0){ char errMess[1000]; - // sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher",iter_number ); - // mexWarnMsgTxt(errMess); + sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number ); + std::cerr << errMess; + retVal = MAX_ITER_REACHED; break; } #if DEBUG_LVL>0 - if(niter>MAX_DEBUG_ITER) + if(iter_number>MAX_DEBUG_ITER) break; - if(++niter%1000==0||niter%1000==1){ + if(iter_number%1000==0||iter_number%1000==1){ double curCost=totalCost(); double sumFlow=0; double a; @@ -1440,12 +1447,13 @@ namespace lemon { for (int i=0; i<_flow.size(); i++) { sumFlow+=_state[i]*_flow[i]; } - mexPrintf("Sum of the flow %.100f\n%d iterations, current cost=%.20f\nReduced cost=%.30f\nPrecision =%.30f\n",sumFlow,niter, curCost,_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]), -EPSILON*(a)); - mexPrintf("Arc in = (%d,%d)\n",_node_id(_source[in_arc]),_node_id(_target[in_arc])); - mexPrintf("Supplies = (%f,%f)\n",_supply[_source[in_arc]],_supply[_target[in_arc]]); - - mexPrintf("%.30f\n%.30f\n%.30f\n%.30f\n%",_cost[in_arc],_pi[_source[in_arc]],_pi[_target[in_arc]],a); - mexEvalString("drawnow;"); + std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; + std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; + std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; + std::cout << _cost[in_arc] << "\n"; + std::cout << _pi[_source[in_arc]] << "\n"; + std::cout << _pi[_target[in_arc]] << "\n"; + std::cout << a << "\n"; } #endif @@ -1459,11 +1467,11 @@ namespace lemon { } #if DEBUG_LVL>0 else{ - mexPrintf("No change\n"); + std::cout << "No change\n"; } #endif #if DEBUG_LVL>1 - mexPrintf("Arc in = (%d,%d)\n",_source[in_arc],_target[in_arc]); + std::cout << "Arc in = (" << _source[in_arc] << ", " << _target[in_arc] << ")\n"; #endif } @@ -1478,34 +1486,36 @@ namespace lemon { for (int i=0; i<_flow.size(); i++) { sumFlow+=_state[i]*_flow[i]; } - mexPrintf("Sum of the flow %.100f\n%d iterations, current cost=%.20f\nReduced cost=%.30f\nPrecision =%.30f",sumFlow,niter, curCost,_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]), -EPSILON*(a)); - mexPrintf("Arc in = (%d,%d)\n",_node_id(_source[in_arc]),_node_id(_target[in_arc])); - mexPrintf("Supplies = (%f,%f)\n",_supply[_source[in_arc]],_supply[_target[in_arc]]); + + std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; + + std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; + std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; - mexEvalString("drawnow;"); #endif #if DEBUG_LVL>1 - double sumFlow=0; + sumFlow=0; for (int i=0; i<_flow.size(); i++) { sumFlow+=_state[i]*_flow[i]; if (_state[i]==STATE_TREE) { - mexPrintf("Non zero value at (%d,%d)\n",_node_num+1-_source[i],_node_num+1-_target[i]); + std::cout << "Non zero value at (" << _node_num+1-_source[i] << ", " << _node_num+1-_target[i] << ")\n"; } } - mexPrintf("Sum of the flow %.100f\n%d iterations, current cost=%.20f\n",sumFlow,niter, totalCost()); - mexEvalString("drawnow;"); + std::cout << "Sum of the flow " << sumFlow << "\n"<< niter <<" iterations, current cost=" << totalCost() << "\n"; #endif // Check feasibility - for (int e = _search_arc_num; e != _all_arc_num; ++e) { - if (_flow[e] != 0){ - if (abs(_flow[e]) > EPSILON) - return INFEASIBLE; - else - _flow[e]=0; + if( retVal == OPTIMAL){ + for (int e = _search_arc_num; e != _all_arc_num; ++e) { + if (_flow[e] != 0){ + if (abs(_flow[e]) > EPSILON) + return INFEASIBLE; + else + _flow[e]=0; + } } - } + } // Shift potentials to meet the requirements of the GEQ/LEQ type // optimality conditions @@ -1531,7 +1541,7 @@ namespace lemon { } } - return OPTIMAL; + return retVal; } }; //class NetworkSimplexSimple diff --git a/test/test_ot.py b/test/test_ot.py index acd8718..ea6d9dc 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -4,12 +4,15 @@ # # License: MIT License +import warnings + import numpy as np + import ot +from ot.datasets import get_1D_gauss as gauss def test_doctest(): - import doctest # test lp solver @@ -66,9 +69,6 @@ def test_emd_empty(): def test_emd2_multi(): - - from ot.datasets import get_1D_gauss as gauss - n = 1000 # nb bins # bin positions @@ -100,3 +100,109 @@ def test_emd2_multi(): ot.toc('multi proc : {} s') np.testing.assert_allclose(emd1, emdn) + + # emd loss multipro proc with log + ot.tic() + emdn = ot.emd2(a, b, M, log=True, return_matrix=True) + ot.toc('multi proc : {} s') + + for i in range(len(emdn)): + emd = emdn[i] + log = emd[1] + cost = emd[0] + check_duality_gap(a, b[:, i], M, log['G'], log['u'], log['v'], cost) + emdn[i] = cost + + emdn = np.array(emdn) + np.testing.assert_allclose(emd1, emdn) + + +def test_warnings(): + n = 100 # nb bins + m = 100 # nb bins + + mean1 = 30 + mean2 = 50 + + # bin positions + x = np.arange(n, dtype=np.float64) + y = np.arange(m, dtype=np.float64) + + # Gaussian distributions + a = gauss(n, m=mean1, s=5) # m= mean, s= std + + b = gauss(m, m=mean2, s=10) + + # loss matrix + M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2) + + print('Computing {} EMD '.format(1)) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + print('Computing {} EMD '.format(1)) + ot.emd(a, b, M, numItermax=1) + assert "numItermax" in str(w[-1].message) + assert len(w) == 1 + a[0] = 100 + print('Computing {} EMD '.format(2)) + ot.emd(a, b, M) + assert "infeasible" in str(w[-1].message) + assert len(w) == 2 + a[0] = -1 + print('Computing {} EMD '.format(2)) + ot.emd(a, b, M) + assert "infeasible" in str(w[-1].message) + assert len(w) == 3 + + +def test_dual_variables(): + n = 5000 # nb bins + m = 6000 # nb bins + + mean1 = 1000 + mean2 = 1100 + + # bin positions + x = np.arange(n, dtype=np.float64) + y = np.arange(m, dtype=np.float64) + + # Gaussian distributions + a = gauss(n, m=mean1, s=5) # m= mean, s= std + + b = gauss(m, m=mean2, s=10) + + # loss matrix + M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2) + + print('Computing {} EMD '.format(1)) + + # emd loss 1 proc + ot.tic() + G, log = ot.emd(a, b, M, log=True) + ot.toc('1 proc : {} s') + + ot.tic() + G2 = ot.emd(b, a, np.ascontiguousarray(M.T)) + ot.toc('1 proc : {} s') + + cost1 = (G * M).sum() + # Check symmetry + np.testing.assert_array_almost_equal(cost1, (M * G2.T).sum()) + # Check with closed-form solution for gaussians + np.testing.assert_almost_equal(cost1, np.abs(mean1 - mean2)) + + # Check that both cost computations are equivalent + np.testing.assert_almost_equal(cost1, log['cost']) + check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost']) + + +def check_duality_gap(a, b, M, G, u, v, cost): + cost_dual = np.vdot(a, u) + np.vdot(b, v) + # Check that dual and primal cost are equal + np.testing.assert_almost_equal(cost_dual, cost) + + [ind1, ind2] = np.nonzero(G) + + # Check that reduced cost is zero on transport arcs + np.testing.assert_array_almost_equal((M - u.reshape(-1, 1) - v.reshape(1, -1))[ind1, ind2], + np.zeros(ind1.size)) |