From 0e86d1bdbc0dcf7ffdb943637f62df5de4612ad0 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Thu, 13 Jul 2017 15:19:42 +0900 Subject: Removed references to matlab Also: - added error message when maxiter is reached - added debug logs --- ot/lp/network_simplex_simple.h | 53 +++++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 24 deletions(-) (limited to 'ot') diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 64856a0..125c818 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -28,6 +28,12 @@ #ifndef LEMON_NETWORK_SIMPLEX_SIMPLE_H #define LEMON_NETWORK_SIMPLEX_SIMPLE_H #define DEBUG_LVL 0 + +#if DEBUG_LVL>0 +#include +#endif + + #define EPSILON 10*2.2204460492503131e-016 #define MAX_DEBUG_ITER 100000 @@ -220,7 +226,7 @@ namespace lemon { /// mixed order in the internal data structure. /// In special cases, it could lead to better overall performance, /// but it is usually slower. Therefore it is disabled by default. - NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,double maxiters) : + NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,int maxiters) : _graph(graph), //_arc_id(graph), _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), MAX(std::numeric_limits::max()), @@ -278,7 +284,7 @@ namespace lemon { private: - double max_iter; + int max_iter; TEMPLATE_DIGRAPH_TYPEDEFS(GR); typedef std::vector IntVector; @@ -676,14 +682,12 @@ namespace lemon { /// \see resetParams(), reset() ProblemType run() { #if DEBUG_LVL>0 - mexPrintf("OPTIMAL = %d\nINFEASIBLE = %d\nUNBOUNDED = %d\n",OPTIMAL,INFEASIBLE,UNBOUNDED); - mexEvalString("drawnow;"); + std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "nUNBOUNDED = " << UNBOUNDED << "\n"; #endif if (!init()) return INFEASIBLE; #if DEBUG_LVL>0 - mexPrintf("Init done, starting iterations\n"); - mexEvalString("drawnow;"); + std::cout << "Init done, starting iterations\n"; #endif return start(); } @@ -1424,8 +1428,8 @@ namespace lemon { while (pivot.findEnteringArc()) { if(++iter_number>=max_iter&&max_iter>0){ char errMess[1000]; - // sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher",iter_number ); - // mexWarnMsgTxt(errMess); + sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number ); + std::cerr << errMess; break; } #if DEBUG_LVL>0 @@ -1440,12 +1444,13 @@ namespace lemon { for (int i=0; i<_flow.size(); i++) { sumFlow+=_state[i]*_flow[i]; } - mexPrintf("Sum of the flow %.100f\n%d iterations, current cost=%.20f\nReduced cost=%.30f\nPrecision =%.30f\n",sumFlow,niter, curCost,_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]), -EPSILON*(a)); - mexPrintf("Arc in = (%d,%d)\n",_node_id(_source[in_arc]),_node_id(_target[in_arc])); - mexPrintf("Supplies = (%f,%f)\n",_supply[_source[in_arc]],_supply[_target[in_arc]]); - - mexPrintf("%.30f\n%.30f\n%.30f\n%.30f\n%",_cost[in_arc],_pi[_source[in_arc]],_pi[_target[in_arc]],a); - mexEvalString("drawnow;"); + std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; + std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; + std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; + std::cout << _cost[in_arc] << "\n"; + std::cout << _pi[_source[in_arc]] << "\n"; + std::cout << _pi[_target[in_arc]] << "\n"; + std::cout << a << "\n"; } #endif @@ -1459,11 +1464,11 @@ namespace lemon { } #if DEBUG_LVL>0 else{ - mexPrintf("No change\n"); + std::cout << "No change\n"; } #endif #if DEBUG_LVL>1 - mexPrintf("Arc in = (%d,%d)\n",_source[in_arc],_target[in_arc]); + std::cout << "Arc in = (" << _source[in_arc] << ", " << _target[in_arc] << ")\n"; #endif } @@ -1478,23 +1483,23 @@ namespace lemon { for (int i=0; i<_flow.size(); i++) { sumFlow+=_state[i]*_flow[i]; } - mexPrintf("Sum of the flow %.100f\n%d iterations, current cost=%.20f\nReduced cost=%.30f\nPrecision =%.30f",sumFlow,niter, curCost,_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]), -EPSILON*(a)); - mexPrintf("Arc in = (%d,%d)\n",_node_id(_source[in_arc]),_node_id(_target[in_arc])); - mexPrintf("Supplies = (%f,%f)\n",_supply[_source[in_arc]],_supply[_target[in_arc]]); + + std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; + + std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; + std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; - mexEvalString("drawnow;"); #endif #if DEBUG_LVL>1 - double sumFlow=0; + sumFlow=0; for (int i=0; i<_flow.size(); i++) { sumFlow+=_state[i]*_flow[i]; if (_state[i]==STATE_TREE) { - mexPrintf("Non zero value at (%d,%d)\n",_node_num+1-_source[i],_node_num+1-_target[i]); + std::cout << "Non zero value at (" << _node_num+1-_source[i] << ", " << _node_num+1-_target[i] << ")\n"; } } - mexPrintf("Sum of the flow %.100f\n%d iterations, current cost=%.20f\n",sumFlow,niter, totalCost()); - mexEvalString("drawnow;"); + std::cout << "Sum of the flow " << sumFlow << "\n"<< niter <<" iterations, current cost=" << totalCost() << "\n"; #endif // Check feasibility for (int e = _search_arc_num; e != _all_arc_num; ++e) { -- cgit v1.2.3 From 55a38f8253e5831105d2c329f4d8ed77686d1330 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Thu, 13 Jul 2017 15:30:39 +0900 Subject: Added optional maximal number of iteration --- ot/lp/EMD.h | 2 +- ot/lp/EMD_wrapper.cpp | 3 +-- ot/lp/__init__.py | 10 +++++----- ot/lp/emd_wrap.pyx | 10 +++++----- ot/lp/network_simplex_simple.h | 2 +- 5 files changed, 13 insertions(+), 14 deletions(-) (limited to 'ot') diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index 40d7192..59a5af8 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -24,6 +24,6 @@ using namespace lemon; typedef unsigned int node_id_type; -void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost); +void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter); #endif diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index cad4750..2d448a0 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -15,11 +15,10 @@ #include "EMD.h" -void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost) { +void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) { // beware M and C anre strored in row major C style!!! int n, m, i,cur; double max; - int max_iter=10000; typedef FullBipartiteDigraph Digraph; DIGRAPH_TYPEDEFS(FullBipartiteDigraph); diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index db3da78..673242d 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -11,7 +11,7 @@ import multiprocessing -def emd(a, b, M): +def emd(a, b, M, max_iter=-1): """Solves the Earth Movers distance problem and returns the OT matrix @@ -80,9 +80,9 @@ def emd(a, b, M): if len(b) == 0: b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - return emd_c(a, b, M) + return emd_c(a, b, M, max_iter) -def emd2(a, b, M,processes=multiprocessing.cpu_count()): +def emd2(a, b, M,processes=multiprocessing.cpu_count(), max_iter=-1): """Solves the Earth Movers distance problem and returns the loss .. math:: @@ -151,12 +151,12 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count()): b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] if len(b.shape)==1: - return emd2_c(a, b, M) + return emd2_c(a, b, M, max_iter) else: nb=b.shape[1] #res=[emd2_c(a,b[:,i].copy(),M) for i in range(nb)] def f(b): - return emd2_c(a,b,M) + return emd2_c(a,b,M, max_iter) res= parmap(f, [b[:,i] for i in range(nb)],processes) return np.array(res) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 46794ab..e8fdba4 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -12,13 +12,13 @@ cimport cython cdef extern from "EMD.h": - void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost) + void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) @cython.boundscheck(False) @cython.wraparound(False) -def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): +def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int maxiter): """ Solves the Earth Movers distance problem and returns the optimal transport matrix @@ -66,13 +66,13 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod b=np.ones((n2,))/n2 # calling the function - EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, &cost) + EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, &cost, maxiter) return G @cython.boundscheck(False) @cython.wraparound(False) -def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): +def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int maxiter): """ Solves the Earth Movers distance problem and returns the optimal transport loss @@ -120,7 +120,7 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo b=np.ones((n2,))/n2 # calling the function - EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, &cost) + EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, &cost, maxiter) cost=0 for i in range(n1): diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 125c818..08449f6 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -1426,7 +1426,7 @@ namespace lemon { //pivot.setDantzig(true); // Execute the Network Simplex algorithm while (pivot.findEnteringArc()) { - if(++iter_number>=max_iter&&max_iter>0){ + if(max_iter > 0 && ++iter_number>=max_iter&&max_iter>0){ char errMess[1000]; sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number ); std::cerr << errMess; -- cgit v1.2.3 From 0faef7fde7e64705b4f0ed6618a0cfd25319bdc7 Mon Sep 17 00:00:00 2001 From: arolet Date: Fri, 14 Jul 2017 15:19:55 +0900 Subject: Removed unused variable max Probably a legacy normalization variable --- ot/lp/EMD_wrapper.cpp | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) (limited to 'ot') diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 2d448a0..d97ba46 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -15,10 +15,10 @@ #include "EMD.h" -void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) { +void EMD_wrap(int n1, int n2, double *X, double *Y, + double *D, double *G, double *cost, int max_iter) { // beware M and C anre strored in row major C style!!! int n, m, i,cur; - double max; typedef FullBipartiteDigraph Digraph; DIGRAPH_TYPEDEFS(FullBipartiteDigraph); @@ -39,7 +39,6 @@ void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double * } } - // Define the graph std::vector indI(n), indJ(m); @@ -49,28 +48,23 @@ void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double * // Set supply and demand, don't account for 0 values (faster) - max=0; cur=0; for (node_id_type i=0; i0) { weights1[ di.nodeFromId(cur) ] = val; - max+=val; indI[cur++]=i; } } // Demand is actually negative supply... - max=0; cur=0; for (node_id_type i=0; i0) { weights2[ di.nodeFromId(cur) ] = -val; indJ[cur++]=i; - - max-=val; } } @@ -78,14 +72,10 @@ void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double * net.supplyMap(&weights1[0], n, &weights2[0], m); // Set the cost of each edge - max=0; for (node_id_type i=0; imax) { - max=val; - } } } -- cgit v1.2.3 From 1fcb7d0ffbc5b00ed20b5ded2e7f1001dc914d6e Mon Sep 17 00:00:00 2001 From: arolet Date: Fri, 14 Jul 2017 15:38:20 +0900 Subject: Removed some references to node_id_type node_id_type is really always int, it makes code hard to read though. In lemon they needed the typedef because they have more complicated graphs. --- ot/lp/EMD_wrapper.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'ot') diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index d97ba46..d719c6e 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -25,14 +25,14 @@ void EMD_wrap(int n1, int n2, double *X, double *Y, // Get the number of non zero coordinates for r and c n=0; - for (node_id_type i=0; i0) { n++; } } m=0; - for (node_id_type i=0; i0) { m++; @@ -49,10 +49,10 @@ void EMD_wrap(int n1, int n2, double *X, double *Y, // Set supply and demand, don't account for 0 values (faster) cur=0; - for (node_id_type i=0; i0) { - weights1[ di.nodeFromId(cur) ] = val; + weights1[ cur ] = val; indI[cur++]=i; } } @@ -60,10 +60,10 @@ void EMD_wrap(int n1, int n2, double *X, double *Y, // Demand is actually negative supply... cur=0; - for (node_id_type i=0; i0) { - weights2[ di.nodeFromId(cur) ] = -val; + weights2[ cur ] = -val; indJ[cur++]=i; } } @@ -72,8 +72,8 @@ void EMD_wrap(int n1, int n2, double *X, double *Y, net.supplyMap(&weights1[0], n, &weights2[0], m); // Set the cost of each edge - for (node_id_type i=0; i Date: Fri, 21 Jul 2017 12:12:21 +0900 Subject: Cleaned optimal plan and optimal cost computation --- ot/lp/EMD_wrapper.cpp | 13 ++++++------- ot/lp/emd_wrap.pyx | 5 ----- test/test_emd.py | 10 ++++++++-- 3 files changed, 14 insertions(+), 14 deletions(-) (limited to 'ot') diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index d719c6e..cc13230 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -93,14 +93,13 @@ void EMD_wrap(int n1, int n2, double *X, double *Y, } } else { - for (node_id_type i=0; i a.data, b.data, M.data, G.data, &cost, maxiter) - - cost=0 - for i in range(n1): - for j in range(n2): - cost+=G[i,j]*M[i,j] return cost diff --git a/test/test_emd.py b/test/test_emd.py index eb1c5c5..4757cd1 100644 --- a/test/test_emd.py +++ b/test/test_emd.py @@ -43,11 +43,17 @@ ot.toc('1 proc : {} s') cost1 = (G * M).sum() +# emd loss 1 proc +ot.tic() +cost_emd2 = ot.emd2(a,b,M) +ot.toc('1 proc : {} s') + ot.tic() G = ot.emd(b, a, np.ascontiguousarray(M.T)) ot.toc('1 proc : {} s') cost2 = (G * M.T).sum() -assert np.abs(cost1-cost2) < tol -assert np.abs(cost1-np.abs(mean1-mean2)) < tol +assert np.abs(cost1-cost_emd2)/np.abs(cost1) < tol +assert np.abs(cost1-cost2)/np.abs(cost1) < tol +assert np.abs(cost1-np.abs(mean1-mean2))/np.abs(cost1) < tol -- cgit v1.2.3 From 88c62c39a9623e8b58ebb776a9deddc96b43b4a0 Mon Sep 17 00:00:00 2001 From: arolet Date: Fri, 21 Jul 2017 12:12:48 +0900 Subject: Added dual variables computations --- ot/lp/EMD.h | 3 ++- ot/lp/EMD_wrapper.cpp | 6 ++++-- ot/lp/__init__.py | 11 +++++++---- ot/lp/emd_wrap.pyx | 22 ++++++++++++++++------ 4 files changed, 29 insertions(+), 13 deletions(-) (limited to 'ot') diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index 59a5af8..fb7feca 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -24,6 +24,7 @@ using namespace lemon; typedef unsigned int node_id_type; -void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter); +void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, + double* alpha, double* beta, double *cost, int max_iter); #endif diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index cc13230..6bda6a7 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -15,8 +15,8 @@ #include "EMD.h" -void EMD_wrap(int n1, int n2, double *X, double *Y, - double *D, double *G, double *cost, int max_iter) { +void EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, + double* alpha, double* beta, double *cost, int max_iter) { // beware M and C anre strored in row major C style!!! int n, m, i,cur; @@ -99,6 +99,8 @@ void EMD_wrap(int n1, int n2, double *X, double *Y, int i = di.source(a); int j = di.target(a); *(G+indI[i]*n2+indJ[j-n]) = net.flow(a); + *(alpha + indI[i]) = net.potential(i); + *(beta + indJ[j-n]) = net.potential(j); } }; diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 673242d..915a18c 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -11,7 +11,7 @@ import multiprocessing -def emd(a, b, M, max_iter=-1): +def emd(a, b, M, dual_variables=False, max_iter=-1): """Solves the Earth Movers distance problem and returns the OT matrix @@ -80,7 +80,10 @@ def emd(a, b, M, max_iter=-1): if len(b) == 0: b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - return emd_c(a, b, M, max_iter) + G, alpha, beta = emd_c(a, b, M, max_iter) + if dual_variables: + return G, alpha, beta + return G def emd2(a, b, M,processes=multiprocessing.cpu_count(), max_iter=-1): """Solves the Earth Movers distance problem and returns the loss @@ -151,12 +154,12 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count(), max_iter=-1): b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] if len(b.shape)==1: - return emd2_c(a, b, M, max_iter) + return emd2_c(a, b, M, max_iter)[0] else: nb=b.shape[1] #res=[emd2_c(a,b[:,i].copy(),M) for i in range(nb)] def f(b): - return emd2_c(a,b,M, max_iter) + return emd2_c(a,b,M, max_iter)[0] res= parmap(f, [b[:,i] for i in range(nb)],processes) return np.array(res) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index c4ba125..813596f 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -12,7 +12,8 @@ cimport cython cdef extern from "EMD.h": - void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) + void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, + double* alpha, double* beta, int max_iter) @@ -58,6 +59,8 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod cdef float cost=0 cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) + cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1) + cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2) if not len(a): a=np.ones((n1,))/n1 @@ -65,10 +68,13 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod if not len(b): b=np.ones((n2,))/n2 + print alpha.size + print beta.size # calling the function - EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, &cost, maxiter) + EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, + alpha.data, beta.data, &cost, maxiter) - return G + return G, alpha, beta @cython.boundscheck(False) @cython.wraparound(False) @@ -112,15 +118,19 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo cdef float cost=0 cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) + + cdef np.ndarray[double, ndim = 1, mode = "c"] alpha = np.zeros([n1]) + cdef np.ndarray[double, ndim = 1, mode = "c"] beta = np.zeros([n2]) if not len(a): a=np.ones((n1,))/n1 if not len(b): b=np.ones((n2,))/n2 - # calling the function - EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, &cost, maxiter) + EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, + alpha.data, beta.data, &cost, maxiter) + - return cost + return cost, alpha, beta -- cgit v1.2.3 From db2a70b1f5146d6374af57f4bea66ab95b202e77 Mon Sep 17 00:00:00 2001 From: arolet Date: Fri, 21 Jul 2017 13:33:44 +0900 Subject: Compute cost with primal --- ot/lp/EMD_wrapper.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'ot') diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 6bda6a7..c6cbb04 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -93,12 +93,14 @@ void EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, } } else { - *cost = net.totalCost(); + *cost = 0; Arc a; di.first(a); for (; a != INVALID; di.next(a)) { int i = di.source(a); int j = di.target(a); - *(G+indI[i]*n2+indJ[j-n]) = net.flow(a); + double flow = net.flow(a); + *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); + *(G+indI[i]*n2+indJ[j-n]) = flow; *(alpha + indI[i]) = net.potential(i); *(beta + indJ[j-n]) = net.potential(j); } -- cgit v1.2.3 From c1980a414c879dd1bc1d8904fd43426326741385 Mon Sep 17 00:00:00 2001 From: arolet Date: Fri, 21 Jul 2017 13:34:09 +0900 Subject: Added and passed tests for dual variables --- ot/lp/EMD_wrapper.cpp | 2 +- ot/lp/emd_wrap.pyx | 4 ++-- test/test_emd.py | 28 +++++++++++++++++++--------- 3 files changed, 22 insertions(+), 12 deletions(-) (limited to 'ot') diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index c6cbb04..0977e75 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -101,7 +101,7 @@ void EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, double flow = net.flow(a); *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); *(G+indI[i]*n2+indJ[j-n]) = flow; - *(alpha + indI[i]) = net.potential(i); + *(alpha + indI[i]) = -net.potential(i); *(beta + indJ[j-n]) = net.potential(j); } diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 813596f..435a270 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -57,7 +57,7 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod cdef int n1= M.shape[0] cdef int n2= M.shape[1] - cdef float cost=0 + cdef double cost=0 cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1) cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2) @@ -116,7 +116,7 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo cdef int n1= M.shape[0] cdef int n2= M.shape[1] - cdef float cost=0 + cdef double cost=0 cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) cdef np.ndarray[double, ndim = 1, mode = "c"] alpha = np.zeros([n1]) diff --git a/test/test_emd.py b/test/test_emd.py index 4757cd1..3bf6fa2 100644 --- a/test/test_emd.py +++ b/test/test_emd.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- import numpy as np -import pylab as pl import ot from ot.datasets import get_1D_gauss as gauss @@ -16,8 +15,6 @@ m=6000 # nb bins mean1 = 1000 mean2 = 1100 -tol = 1e-6 - # bin positions x=np.arange(n,dtype=np.float64) y=np.arange(m,dtype=np.float64) @@ -38,10 +35,11 @@ print('Computing {} EMD '.format(1)) # emd loss 1 proc ot.tic() -G = ot.emd(a,b,M) +G, alpha, beta = ot.emd(a,b,M, dual_variables=True) ot.toc('1 proc : {} s') cost1 = (G * M).sum() +cost_dual = np.vdot(a, alpha) + np.vdot(b, beta) # emd loss 1 proc ot.tic() @@ -49,11 +47,23 @@ cost_emd2 = ot.emd2(a,b,M) ot.toc('1 proc : {} s') ot.tic() -G = ot.emd(b, a, np.ascontiguousarray(M.T)) +G2 = ot.emd(b, a, np.ascontiguousarray(M.T)) ot.toc('1 proc : {} s') -cost2 = (G * M.T).sum() +cost2 = (G2 * M.T).sum() + +M_reduced = M - alpha.reshape(-1,1) - beta.reshape(1, -1) + +# Check that both cost computations are equivalent +np.testing.assert_almost_equal(cost1, cost_emd2) +# Check that dual and primal cost are equal +np.testing.assert_almost_equal(cost1, cost_dual) +# Check symmetry +np.testing.assert_almost_equal(cost1, cost2) +# Check with closed-form solution for gaussians +np.testing.assert_almost_equal(cost1, np.abs(mean1-mean2)) + +[ind1, ind2] = np.nonzero(G) -assert np.abs(cost1-cost_emd2)/np.abs(cost1) < tol -assert np.abs(cost1-cost2)/np.abs(cost1) < tol -assert np.abs(cost1-np.abs(mean1-mean2))/np.abs(cost1) < tol +# Check that reduced cost is zero on transport arcs +np.testing.assert_array_almost_equal((M - alpha.reshape(-1, 1) - beta.reshape(1, -1))[ind1, ind2], np.zeros(ind1.size)) \ No newline at end of file -- cgit v1.2.3 From 30bfc5ce5acd98991b3d01e313d0c14f0e600b14 Mon Sep 17 00:00:00 2001 From: Slasnista Date: Mon, 4 Sep 2017 08:46:36 +0200 Subject: correction semi supervised case --- ot/da.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'ot') diff --git a/ot/da.py b/ot/da.py index 564c7b7..e694668 100644 --- a/ot/da.py +++ b/ot/da.py @@ -989,7 +989,7 @@ class BaseTransport(BaseEstimator): # assumes labeled source samples occupy the first rows # and labeled target samples occupy the first columns - classes = np.unique(ys) + classes = [c for c in np.unique(ys) if c != -1] for c in classes: idx_s = np.where((ys != c) & (ys != -1)) idx_t = np.where(yt == c) -- cgit v1.2.3 From 363c5f92a4865527320edcff97036e62a7ca28c9 Mon Sep 17 00:00:00 2001 From: Slasnista Date: Mon, 4 Sep 2017 09:12:32 +0200 Subject: doc string + example --- examples/da/plot_otda_semi_supervised.py | 142 +++++++++++++++++++++++++++++++ ot/da.py | 72 ++++++++++++---- 2 files changed, 196 insertions(+), 18 deletions(-) create mode 100644 examples/da/plot_otda_semi_supervised.py (limited to 'ot') diff --git a/examples/da/plot_otda_semi_supervised.py b/examples/da/plot_otda_semi_supervised.py new file mode 100644 index 0000000..6e6296b --- /dev/null +++ b/examples/da/plot_otda_semi_supervised.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- +""" +============================================ +OTDA unsupervised vs semi-supervised setting +============================================ + +This example introduces a semi supervised domain adaptation in a 2D setting. +It explicits the problem of semi supervised domain adaptation and introduces +some optimal transport approaches to solve it. + +Quantities such as optimal couplings, greater coupling coefficients and +transported samples are represented in order to give a visual understanding +of what the transport methods are doing. +""" + +# Authors: Remi Flamary +# Stanislas Chambon +# +# License: MIT License + +import matplotlib.pylab as pl +import ot + + +############################################################################## +# generate data +############################################################################## + +n_samples_source = 150 +n_samples_target = 150 + +Xs, ys = ot.datasets.get_data_classif('3gauss', n_samples_source) +Xt, yt = ot.datasets.get_data_classif('3gauss2', n_samples_target) + +# Cost matrix +M = ot.dist(Xs, Xt, metric='sqeuclidean') + + +############################################################################## +# Transport source samples onto target samples +############################################################################## + +# unsupervised domain adaptation +ot_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1e-1) +ot_sinkhorn_un.fit(Xs=Xs, Xt=Xt) +transp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs) + +# semi-supervised domain adaptation +ot_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=1e-1) +ot_sinkhorn_semi.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt) +transp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs) + +# semi supervised DA uses available labaled target samples to modify the cost +# matrix involved in the OT problem. The cost of transporting a source sample +# of class A onto a target sample of class B != A is set to infinite, or a +# very large value + + +############################################################################## +# Fig 1 : plots source and target samples + matrix of pairwise distance +############################################################################## + +pl.figure(1, figsize=(10, 10)) +pl.subplot(2, 2, 1) +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') +pl.xticks([]) +pl.yticks([]) +pl.legend(loc=0) +pl.title('Source samples') + +pl.subplot(2, 2, 2) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') +pl.xticks([]) +pl.yticks([]) +pl.legend(loc=0) +pl.title('Target samples') + +pl.subplot(2, 2, 3) +pl.imshow(ot_sinkhorn_un.cost_, interpolation='nearest') +pl.xticks([]) +pl.yticks([]) +pl.title('Cost matrix - unsupervised DA') + +pl.subplot(2, 2, 4) +pl.imshow(ot_sinkhorn_semi.cost_, interpolation='nearest') +pl.xticks([]) +pl.yticks([]) +pl.title('Cost matrix - semisupervised DA') + +pl.tight_layout() + +# the optimal coupling in the semi-supervised DA case will exhibit " shape +# similar" to the cost matrix, (block diagonal matrix) + +############################################################################## +# Fig 2 : plots optimal couplings for the different methods +############################################################################## + +pl.figure(2, figsize=(8, 4)) + +pl.subplot(1, 2, 1) +pl.imshow(ot_sinkhorn_un.coupling_, interpolation='nearest') +pl.xticks([]) +pl.yticks([]) +pl.title('Optimal coupling\nUnsupervised DA') + +pl.subplot(1, 2, 2) +pl.imshow(ot_sinkhorn_semi.coupling_, interpolation='nearest') +pl.xticks([]) +pl.yticks([]) +pl.title('Optimal coupling\nSemi-supervised DA') + +pl.tight_layout() + + +############################################################################## +# Fig 3 : plot transported samples +############################################################################## + +# display transported samples +pl.figure(4, figsize=(8, 4)) +pl.subplot(1, 2, 1) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', + label='Target samples', alpha=0.5) +pl.scatter(transp_Xs_sinkhorn_un[:, 0], transp_Xs_sinkhorn_un[:, 1], c=ys, + marker='+', label='Transp samples', s=30) +pl.title('Transported samples\nEmdTransport') +pl.legend(loc=0) +pl.xticks([]) +pl.yticks([]) + +pl.subplot(1, 2, 2) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', + label='Target samples', alpha=0.5) +pl.scatter(transp_Xs_sinkhorn_semi[:, 0], transp_Xs_sinkhorn_semi[:, 1], c=ys, + marker='+', label='Transp samples', s=30) +pl.title('Transported samples\nSinkhornTransport') +pl.xticks([]) +pl.yticks([]) + +pl.tight_layout() +pl.show() diff --git a/ot/da.py b/ot/da.py index e694668..1d3d0ba 100644 --- a/ot/da.py +++ b/ot/da.py @@ -966,8 +966,12 @@ class BaseTransport(BaseEstimator): The class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. - yt : array-like, shape (n_labeled_target_samples,) - The class labels + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label Returns ------- @@ -1023,8 +1027,12 @@ class BaseTransport(BaseEstimator): The class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. - yt : array-like, shape (n_labeled_target_samples,) - The class labels + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label Returns ------- @@ -1045,8 +1053,12 @@ class BaseTransport(BaseEstimator): The class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. - yt : array-like, shape (n_labeled_target_samples,) - The class labels + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label batch_size : int, optional (default=128) The batch size for out of sample inverse transform @@ -1110,8 +1122,12 @@ class BaseTransport(BaseEstimator): The class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. - yt : array-like, shape (n_labeled_target_samples,) - The class labels + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label batch_size : int, optional (default=128) The batch size for out of sample inverse transform @@ -1241,8 +1257,12 @@ class SinkhornTransport(BaseTransport): The class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. - yt : array-like, shape (n_labeled_target_samples,) - The class labels + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label Returns ------- @@ -1333,8 +1353,12 @@ class EMDTransport(BaseTransport): The class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. - yt : array-like, shape (n_labeled_target_samples,) - The class labels + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label Returns ------- @@ -1434,8 +1458,12 @@ class SinkhornLpl1Transport(BaseTransport): The class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. - yt : array-like, shape (n_labeled_target_samples,) - The class labels + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label Returns ------- @@ -1545,8 +1573,12 @@ class SinkhornL1l2Transport(BaseTransport): The class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. - yt : array-like, shape (n_labeled_target_samples,) - The class labels + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label Returns ------- @@ -1662,8 +1694,12 @@ class MappingTransport(BaseEstimator): The class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. - yt : array-like, shape (n_labeled_target_samples,) - The class labels + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label Returns ------- -- cgit v1.2.3 From 185eb3e2ef34b5ce6b8f90a28a5bcc78432b7fd3 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Tue, 5 Sep 2017 15:10:44 +0900 Subject: Removed prints --- ot/lp/emd_wrap.pyx | 2 -- 1 file changed, 2 deletions(-) (limited to 'ot') diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 435a270..4febb32 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -68,8 +68,6 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod if not len(b): b=np.ones((n2,))/n2 - print alpha.size - print beta.size # calling the function EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, maxiter) -- cgit v1.2.3 From 0bb8ec8bf8061aa7ad2299b04b8368b46b56be41 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Tue, 5 Sep 2017 15:36:03 +0900 Subject: Removed declaration of unused variable --- ot/lp/EMD_wrapper.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'ot') diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 8ac43c7..8e74462 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -18,8 +18,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, double* alpha, double* beta, double *cost, int max_iter) { // beware M and C anre strored in row major C style!!! - int n, m, i, cur; - double max; + int n, m, i, cur; typedef FullBipartiteDigraph Digraph; DIGRAPH_TYPEDEFS(FullBipartiteDigraph); -- cgit v1.2.3 From 3baa34b5504dfbccd6800b59f1f3830a7edf3f20 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Tue, 5 Sep 2017 16:53:55 +0900 Subject: Added include cstdio --- ot/lp/network_simplex_simple.h | 1 + 1 file changed, 1 insertion(+) (limited to 'ot') diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 08449f6..a7743ee 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -49,6 +49,7 @@ #include #include #include +#include #ifdef HASHMAP #include #else -- cgit v1.2.3 From d52b4ea415d9bb669be04ccd0940f9b3d258d0e1 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Tue, 5 Sep 2017 17:15:45 +0900 Subject: Fixed typo and merged emd tests --- ot/lp/__init__.py | 2 +- test/test_emd.py | 68 ------------------------------------------------------- test/test_ot.py | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 64 insertions(+), 72 deletions(-) delete mode 100644 test/test_emd.py (limited to 'ot') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index a14d4e4..6048f60 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -168,6 +168,6 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)] def f(b): - return emd2_c(a,b,M, max_iter)[0] + return emd2_c(a,b,M, numItermax)[0] res= parmap(f, [b[:,i] for i in range(nb)],processes) return np.array(res) diff --git a/test/test_emd.py b/test/test_emd.py deleted file mode 100644 index 0025583..0000000 --- a/test/test_emd.py +++ /dev/null @@ -1,68 +0,0 @@ -#!/usr/bin/env python2 -# -*- coding: utf-8 -*- - -import numpy as np -import ot - -from ot.datasets import get_1D_gauss as gauss -reload(ot.lp) - -#%% parameters - -n=5000 # nb bins -m=6000 # nb bins - -mean1 = 1000 -mean2 = 1100 - -# bin positions -x=np.arange(n,dtype=np.float64) -y=np.arange(m,dtype=np.float64) - -# Gaussian distributions -a=gauss(n,m=mean1,s=5) # m= mean, s= std - -b=gauss(m,m=mean2,s=10) - -# loss matrix -M=ot.dist(x.reshape((-1,1)), y.reshape((-1,1))) ** (1./2) -#M/=M.max() - -#%% - -print('Computing {} EMD '.format(1)) - -# emd loss 1 proc -ot.tic() -G, alpha, beta = ot.emd(a,b,M, dual_variables=True) -ot.toc('1 proc : {} s') - -cost1 = (G * M).sum() -cost_dual = np.vdot(a, alpha) + np.vdot(b, beta) - -# emd loss 1 proc -ot.tic() -cost_emd2 = ot.emd2(a,b,M) -ot.toc('1 proc : {} s') - -ot.tic() -G2 = ot.emd(b, a, np.ascontiguousarray(M.T)) -ot.toc('1 proc : {} s') - -cost2 = (G2 * M.T).sum() - -M_reduced = M - alpha.reshape(-1,1) - beta.reshape(1, -1) - -# Check that both cost computations are equivalent -np.testing.assert_almost_equal(cost1, cost_emd2) -# Check that dual and primal cost are equal -np.testing.assert_almost_equal(cost1, cost_dual) -# Check symmetry -np.testing.assert_almost_equal(cost1, cost2) -# Check with closed-form solution for gaussians -np.testing.assert_almost_equal(cost1, np.abs(mean1-mean2)) - -[ind1, ind2] = np.nonzero(G) - -# Check that reduced cost is zero on transport arcs -np.testing.assert_array_almost_equal((M - alpha.reshape(-1, 1) - beta.reshape(1, -1))[ind1, ind2], np.zeros(ind1.size)) \ No newline at end of file diff --git a/test/test_ot.py b/test/test_ot.py index acd8718..ded6c9f 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -6,6 +6,8 @@ import numpy as np import ot + +from ot.datasets import get_1D_gauss as gauss def test_doctest(): @@ -66,9 +68,6 @@ def test_emd_empty(): def test_emd2_multi(): - - from ot.datasets import get_1D_gauss as gauss - n = 1000 # nb bins # bin positions @@ -100,3 +99,64 @@ def test_emd2_multi(): ot.toc('multi proc : {} s') np.testing.assert_allclose(emd1, emdn) + +def test_dual_variables(): + #%% parameters + + n=5000 # nb bins + m=6000 # nb bins + + mean1 = 1000 + mean2 = 1100 + + # bin positions + x=np.arange(n,dtype=np.float64) + y=np.arange(m,dtype=np.float64) + + # Gaussian distributions + a=gauss(n,m=mean1,s=5) # m= mean, s= std + + b=gauss(m,m=mean2,s=10) + + # loss matrix + M=ot.dist(x.reshape((-1,1)), y.reshape((-1,1))) ** (1./2) + #M/=M.max() + + #%% + + print('Computing {} EMD '.format(1)) + + # emd loss 1 proc + ot.tic() + G, alpha, beta = ot.emd(a,b,M, dual_variables=True) + ot.toc('1 proc : {} s') + + cost1 = (G * M).sum() + cost_dual = np.vdot(a, alpha) + np.vdot(b, beta) + + # emd loss 1 proc + ot.tic() + cost_emd2 = ot.emd2(a,b,M) + ot.toc('1 proc : {} s') + + ot.tic() + G2 = ot.emd(b, a, np.ascontiguousarray(M.T)) + ot.toc('1 proc : {} s') + + cost2 = (G2 * M.T).sum() + + M_reduced = M - alpha.reshape(-1,1) - beta.reshape(1, -1) + + # Check that both cost computations are equivalent + np.testing.assert_almost_equal(cost1, cost_emd2) + # Check that dual and primal cost are equal + np.testing.assert_almost_equal(cost1, cost_dual) + # Check symmetry + np.testing.assert_almost_equal(cost1, cost2) + # Check with closed-form solution for gaussians + np.testing.assert_almost_equal(cost1, np.abs(mean1-mean2)) + + [ind1, ind2] = np.nonzero(G) + + # Check that reduced cost is zero on transport arcs + np.testing.assert_array_almost_equal((M - alpha.reshape(-1, 1) - beta.reshape(1, -1))[ind1, ind2], np.zeros(ind1.size)) -- cgit v1.2.3 From f8c1c8740f9974dcf4aaf191851d62149dceb91c Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Thu, 7 Sep 2017 13:29:46 +0900 Subject: Added MAX_ITER_REACHED flag and warning --- ot/lp/EMD.h | 3 ++- ot/lp/EMD_wrapper.cpp | 21 +++++++---------- ot/lp/emd_wrap.pyx | 29 +++++++++++++---------- ot/lp/network_simplex_simple.h | 46 ++++++++++++++++++++----------------- test/test_ot.py | 52 ++++++++++++++++++++++++++++++++++++++++-- 5 files changed, 102 insertions(+), 49 deletions(-) (limited to 'ot') diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index 15e9115..bb486de 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -26,7 +26,8 @@ typedef unsigned int node_id_type; enum ProblemType { INFEASIBLE, OPTIMAL, - UNBOUNDED + UNBOUNDED, + MAX_ITER_REACHED }; int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int max_iter); diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 8e74462..92663dc 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -29,14 +29,18 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, double val=*(X+i); if (val>0) { n++; - } + }else if(val<0){ + return INFEASIBLE; + } } m=0; for (int i=0; i0) { m++; - } + }else if(val<0){ + return INFEASIBLE; + } } // Define the graph @@ -83,16 +87,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, // Solve the problem with the network simplex algorithm int ret=net.run(); - if (ret!=(int)net.OPTIMAL) { - if (ret==(int)net.INFEASIBLE) { - std::cout << "Infeasible problem"; - } - if (ret==(int)net.UNBOUNDED) - { - std::cout << "Unbounded problem"; - } - } else - { + if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) { *cost = 0; Arc a; di.first(a); for (; a != INVALID; di.next(a)) { @@ -105,7 +100,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, *(beta + indJ[j-n]) = net.potential(j); } - }; + } return ret; diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 7056e0e..9bea154 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -7,6 +7,7 @@ Cython linker with C solver # # License: MIT License +import warnings import numpy as np cimport numpy as np @@ -15,14 +16,14 @@ cimport cython cdef extern from "EMD.h": - int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int max_iter) - cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED + int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int numItermax) + cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @cython.boundscheck(False) @cython.wraparound(False) -def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int max_iter): +def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int numItermax): """ Solves the Earth Movers distance problem and returns the optimal transport matrix @@ -49,7 +50,7 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod target histogram M : (ns,nt) ndarray, float64 loss matrix - max_iter : int + numItermax : int The maximum number of iterations before stopping the optimization algorithm if it has not converged. @@ -76,18 +77,20 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod b=np.ones((n2,))/n2 # calling the function - cdef int resultSolver = EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter) + cdef int resultSolver = EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, numItermax) if resultSolver != OPTIMAL: if resultSolver == INFEASIBLE: - print("Problem infeasible. Try to increase numItermax.") + warnings.warn("Problem infeasible. Check that a and b are in the simplex") elif resultSolver == UNBOUNDED: - print("Problem unbounded") + warnings.warn("Problem unbounded") + elif resultSolver == MAX_ITER_REACHED: + warnings.warn("numItermax reached before optimality. Try to increase numItermax.") return G, alpha, beta @cython.boundscheck(False) @cython.wraparound(False) -def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int max_iter): +def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int numItermax): """ Solves the Earth Movers distance problem and returns the optimal transport loss @@ -114,7 +117,7 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo target histogram M : (ns,nt) ndarray, float64 loss matrix - max_iter : int + numItermax : int The maximum number of iterations before stopping the optimization algorithm if it has not converged. @@ -140,12 +143,14 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo if not len(b): b=np.ones((n2,))/n2 # calling the function - cdef int resultSolver = EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter) + cdef int resultSolver = EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, numItermax) if resultSolver != OPTIMAL: if resultSolver == INFEASIBLE: - print("Problem infeasible. Try to inscrease numItermax.") + warnings.warn("Problem infeasible. Check that a and b are in the simplex") elif resultSolver == UNBOUNDED: - print("Problem unbounded") + warnings.warn("Problem unbounded") + elif resultSolver == MAX_ITER_REACHED: + warnings.warn("numItermax reached before optimality. Try to increase numItermax.") return cost, alpha, beta diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index a7743ee..7c6a4ce 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -34,7 +34,8 @@ #endif -#define EPSILON 10*2.2204460492503131e-016 +#define EPSILON 2.2204460492503131e-15 +#define _EPSILON 1e-8 #define MAX_DEBUG_ITER 100000 @@ -260,7 +261,9 @@ namespace lemon { /// The objective function of the problem is unbounded, i.e. /// there is a directed cycle having negative total cost and /// infinite upper bound. - UNBOUNDED + UNBOUNDED, + /// The maximum number of iteration has been reached + MAX_ITER_REACHED }; /// \brief Constants for selecting the type of the supply constraints. @@ -683,7 +686,7 @@ namespace lemon { /// \see resetParams(), reset() ProblemType run() { #if DEBUG_LVL>0 - std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "nUNBOUNDED = " << UNBOUNDED << "\n"; + std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED\n"; #endif if (!init()) return INFEASIBLE; @@ -941,15 +944,15 @@ namespace lemon { // Initialize internal data structures bool init() { if (_node_num == 0) return false; - /* + // Check the sum of supply values _sum_supply = 0; for (int i = 0; i != _node_num; ++i) { _sum_supply += _supply[i]; } - if ( !((_stype == GEQ && _sum_supply <= _epsilon ) || - (_stype == LEQ && _sum_supply >= -_epsilon )) ) return false; - */ + if ( fabs(_sum_supply) > _EPSILON ) return false; + + _sum_supply = 0; // Initialize artifical cost Cost ART_COST; @@ -1416,13 +1419,11 @@ namespace lemon { ProblemType start() { PivotRuleImpl pivot(*this); double prevCost=-1; + ProblemType retVal = OPTIMAL; // Perform heuristic initial pivots if (!initialPivots()) return UNBOUNDED; -#if DEBUG_LVL>0 - int niter=0; -#endif int iter_number=0; //pivot.setDantzig(true); // Execute the Network Simplex algorithm @@ -1431,12 +1432,13 @@ namespace lemon { char errMess[1000]; sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number ); std::cerr << errMess; + retVal = MAX_ITER_REACHED; break; } #if DEBUG_LVL>0 - if(niter>MAX_DEBUG_ITER) + if(iter_number>MAX_DEBUG_ITER) break; - if(++niter%1000==0||niter%1000==1){ + if(iter_number%1000==0||iter_number%1000==1){ double curCost=totalCost(); double sumFlow=0; double a; @@ -1445,7 +1447,7 @@ namespace lemon { for (int i=0; i<_flow.size(); i++) { sumFlow+=_state[i]*_flow[i]; } - std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; + std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; std::cout << _cost[in_arc] << "\n"; @@ -1503,15 +1505,17 @@ namespace lemon { std::cout << "Sum of the flow " << sumFlow << "\n"<< niter <<" iterations, current cost=" << totalCost() << "\n"; #endif // Check feasibility - for (int e = _search_arc_num; e != _all_arc_num; ++e) { - if (_flow[e] != 0){ - if (abs(_flow[e]) > EPSILON) - return INFEASIBLE; - else - _flow[e]=0; + if( retVal == OPTIMAL){ + for (int e = _search_arc_num; e != _all_arc_num; ++e) { + if (_flow[e] != 0){ + if (abs(_flow[e]) > EPSILON) + return INFEASIBLE; + else + _flow[e]=0; + } } - } + } // Shift potentials to meet the requirements of the GEQ/LEQ type // optimality conditions @@ -1537,7 +1541,7 @@ namespace lemon { } } - return OPTIMAL; + return retVal; } }; //class NetworkSimplexSimple diff --git a/test/test_ot.py b/test/test_ot.py index 6f0f7c9..8a19cf6 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -8,6 +8,7 @@ import numpy as np import ot from ot.datasets import get_1D_gauss as gauss +import warnings def test_doctest(): @@ -100,9 +101,56 @@ def test_emd2_multi(): np.testing.assert_allclose(emd1, emdn) -def test_dual_variables(): - # %% parameters +def test_warnings(): + n = 100 # nb bins + m = 100 # nb bins + + mean1 = 30 + mean2 = 50 + + # bin positions + x = np.arange(n, dtype=np.float64) + y = np.arange(m, dtype=np.float64) + + # Gaussian distributions + a = gauss(n, m=mean1, s=5) # m= mean, s= std + + b = gauss(m, m=mean2, s=10) + # loss matrix + M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2) + # M/=M.max() + + # %% + + print('Computing {} EMD '.format(1)) + G, alpha, beta = ot.emd(a, b, M, dual_variables=True) + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter("always") + # Trigger a warning. + print('Computing {} EMD '.format(1)) + G, alpha, beta = ot.emd(a, b, M, dual_variables=True, numItermax=1) + # Verify some things + assert "numItermax" in str(w[-1].message) + assert len(w) == 1 + # Trigger a warning. + a[0]=100 + print('Computing {} EMD '.format(2)) + G, alpha, beta = ot.emd(a, b, M, dual_variables=True) + # Verify some things + assert "infeasible" in str(w[-1].message) + assert len(w) == 2 + # Trigger a warning. + a[0]=-1 + print('Computing {} EMD '.format(2)) + G, alpha, beta = ot.emd(a, b, M, dual_variables=True) + # Verify some things + assert "infeasible" in str(w[-1].message) + assert len(w) == 3 + + +def test_dual_variables(): n = 5000 # nb bins m = 6000 # nb bins -- cgit v1.2.3 From 12d9b3ff72e9669ccc0162e82b7a33beb51d3e25 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Thu, 7 Sep 2017 13:50:41 +0900 Subject: Return dual variables in an optional dictionary Also removed some code duplication --- ot/lp/__init__.py | 24 +++++++++++++------ ot/lp/emd_wrap.pyx | 69 +----------------------------------------------------- test/test_ot.py | 20 ++++++---------- 3 files changed, 25 insertions(+), 88 deletions(-) (limited to 'ot') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 6048f60..c15e6b9 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -9,12 +9,12 @@ Solvers for the original linear program OT problem import numpy as np # import compiled emd -from .emd_wrap import emd_c, emd2_c +from .emd_wrap import emd_c from ..utils import parmap import multiprocessing -def emd(a, b, M, numItermax=100000, dual_variables=False): +def emd(a, b, M, numItermax=100000, log=False): """Solves the Earth Movers distance problem and returns the OT matrix @@ -42,11 +42,17 @@ def emd(a, b, M, numItermax=100000, dual_variables=False): numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost and dual + variables. Otherwise returns only the optimal transportation matrix. Returns ------- gamma: (ns x nt) ndarray Optimal transportation matrix for the given parameters + log: dict + If input log is true, a dictionary containing the cost and dual + variables Examples @@ -86,9 +92,13 @@ def emd(a, b, M, numItermax=100000, dual_variables=False): if len(b) == 0: b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - G, alpha, beta = emd_c(a, b, M, numItermax) - if dual_variables: - return G, alpha, beta + G, cost, u, v = emd_c(a, b, M, numItermax) + if log: + log = {} + log['cost'] = cost + log['u'] = u + log['v'] = v + return G, log return G def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): @@ -163,11 +173,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] if len(b.shape)==1: - return emd2_c(a, b, M, numItermax)[0] + return emd_c(a, b, M, numItermax)[1] nb = b.shape[1] # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)] def f(b): - return emd2_c(a,b,M, numItermax)[0] + return emd_c(a,b,M, numItermax)[1] res= parmap(f, [b[:,i] for i in range(nb)],processes) return np.array(res) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 9bea154..5618dfc 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -86,71 +86,4 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod elif resultSolver == MAX_ITER_REACHED: warnings.warn("numItermax reached before optimality. Try to increase numItermax.") - return G, alpha, beta - -@cython.boundscheck(False) -@cython.wraparound(False) -def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int numItermax): - """ - Solves the Earth Movers distance problem and returns the optimal transport loss - - gamm=emd(a,b,M) - - .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F - - s.t. \gamma 1 = a - - \gamma^T 1= b - - \gamma\geq 0 - where : - - - M is the metric cost matrix - - a and b are the sample weights - - Parameters - ---------- - a : (ns,) ndarray, float64 - source histogram - b : (nt,) ndarray, float64 - target histogram - M : (ns,nt) ndarray, float64 - loss matrix - numItermax : int - The maximum number of iterations before stopping the optimization - algorithm if it has not converged. - - - Returns - ------- - gamma: (ns x nt) ndarray - Optimal transportation matrix for the given parameters - - """ - cdef int n1= M.shape[0] - cdef int n2= M.shape[1] - - cdef double cost=0 - cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) - - cdef np.ndarray[double, ndim = 1, mode = "c"] alpha = np.zeros([n1]) - cdef np.ndarray[double, ndim = 1, mode = "c"] beta = np.zeros([n2]) - - if not len(a): - a=np.ones((n1,))/n1 - - if not len(b): - b=np.ones((n2,))/n2 - # calling the function - cdef int resultSolver = EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, numItermax) - if resultSolver != OPTIMAL: - if resultSolver == INFEASIBLE: - warnings.warn("Problem infeasible. Check that a and b are in the simplex") - elif resultSolver == UNBOUNDED: - warnings.warn("Problem unbounded") - elif resultSolver == MAX_ITER_REACHED: - warnings.warn("numItermax reached before optimality. Try to increase numItermax.") - - return cost, alpha, beta - + return G, cost, alpha, beta diff --git a/test/test_ot.py b/test/test_ot.py index 8a19cf6..78f64ab 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -124,27 +124,26 @@ def test_warnings(): # %% print('Computing {} EMD '.format(1)) - G, alpha, beta = ot.emd(a, b, M, dual_variables=True) with warnings.catch_warnings(record=True) as w: # Cause all warnings to always be triggered. warnings.simplefilter("always") # Trigger a warning. print('Computing {} EMD '.format(1)) - G, alpha, beta = ot.emd(a, b, M, dual_variables=True, numItermax=1) + G = ot.emd(a, b, M, numItermax=1) # Verify some things assert "numItermax" in str(w[-1].message) assert len(w) == 1 # Trigger a warning. a[0]=100 print('Computing {} EMD '.format(2)) - G, alpha, beta = ot.emd(a, b, M, dual_variables=True) + G = ot.emd(a, b, M) # Verify some things assert "infeasible" in str(w[-1].message) assert len(w) == 2 # Trigger a warning. a[0]=-1 print('Computing {} EMD '.format(2)) - G, alpha, beta = ot.emd(a, b, M, dual_variables=True) + G = ot.emd(a, b, M) # Verify some things assert "infeasible" in str(w[-1].message) assert len(w) == 3 @@ -176,16 +175,11 @@ def test_dual_variables(): # emd loss 1 proc ot.tic() - G, alpha, beta = ot.emd(a, b, M, dual_variables=True) + G, log = ot.emd(a, b, M, log=True) ot.toc('1 proc : {} s') cost1 = (G * M).sum() - cost_dual = np.vdot(a, alpha) + np.vdot(b, beta) - - # emd loss 1 proc - ot.tic() - cost_emd2 = ot.emd2(a, b, M) - ot.toc('1 proc : {} s') + cost_dual = np.vdot(a, log['u']) + np.vdot(b, log['v']) ot.tic() G2 = ot.emd(b, a, np.ascontiguousarray(M.T)) @@ -194,7 +188,7 @@ def test_dual_variables(): cost2 = (G2 * M.T).sum() # Check that both cost computations are equivalent - np.testing.assert_almost_equal(cost1, cost_emd2) + np.testing.assert_almost_equal(cost1, log['cost']) # Check that dual and primal cost are equal np.testing.assert_almost_equal(cost1, cost_dual) # Check symmetry @@ -205,5 +199,5 @@ def test_dual_variables(): [ind1, ind2] = np.nonzero(G) # Check that reduced cost is zero on transport arcs - np.testing.assert_array_almost_equal((M - alpha.reshape(-1, 1) - beta.reshape(1, -1))[ind1, ind2], + np.testing.assert_array_almost_equal((M - log['u'].reshape(-1, 1) - log['v'].reshape(1, -1))[ind1, ind2], np.zeros(ind1.size)) -- cgit v1.2.3 From ab65f86304b03a967054eeeaf73b8c8277618d65 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Thu, 7 Sep 2017 14:35:35 +0900 Subject: Added log option to muliprocess emd --- ot/lp/__init__.py | 39 ++++++++++++++++++++++++------------- test/test_ot.py | 57 ++++++++++++++++++++++++++++++------------------------- 2 files changed, 57 insertions(+), 39 deletions(-) (limited to 'ot') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index c15e6b9..8edd8ec 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -7,11 +7,13 @@ Solvers for the original linear program OT problem # # License: MIT License +import multiprocessing + import numpy as np + # import compiled emd from .emd_wrap import emd_c from ..utils import parmap -import multiprocessing def emd(a, b, M, numItermax=100000, log=False): @@ -88,9 +90,9 @@ def emd(a, b, M, numItermax=100000, log=False): # if empty array given then use unifor distributions if len(a) == 0: - a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0] + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] G, cost, u, v = emd_c(a, b, M, numItermax) if log: @@ -101,7 +103,8 @@ def emd(a, b, M, numItermax=100000, log=False): return G, log return G -def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): + +def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False): """Solves the Earth Movers distance problem and returns the loss .. math:: @@ -168,16 +171,26 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000): # if empty array given then use unifor distributions if len(a) == 0: - a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0] + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] - - if len(b.shape)==1: - return emd_c(a, b, M, numItermax)[1] + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + + if log: + def f(b): + G, cost, u, v = emd_c(a, b, M, numItermax) + log = {} + log['G'] = G + log['u'] = u + log['v'] = v + return [cost, log] + else: + def f(b): + return emd_c(a, b, M, numItermax)[1] + + if len(b.shape) == 1: + return f(b) nb = b.shape[1] # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)] - def f(b): - return emd_c(a,b,M, numItermax)[1] - res= parmap(f, [b[:,i] for i in range(nb)],processes) - return np.array(res) + res = parmap(f, [b[:, i] for i in range(nb)], processes) + return res diff --git a/test/test_ot.py b/test/test_ot.py index 78f64ab..feadef4 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -4,11 +4,12 @@ # # License: MIT License +import warnings + import numpy as np import ot from ot.datasets import get_1D_gauss as gauss -import warnings def test_doctest(): @@ -100,6 +101,21 @@ def test_emd2_multi(): np.testing.assert_allclose(emd1, emdn) + # emd loss multipro proc with log + ot.tic() + emdn = ot.emd2(a, b, M, log=True) + ot.toc('multi proc : {} s') + + for i in range(len(emdn)): + emd = emdn[i] + log = emd[1] + cost = emd[0] + check_duality_gap(a, b[:, i], M, log['G'], log['u'], log['v'], cost) + emdn[i] = cost + + emdn = np.array(emdn) + np.testing.assert_allclose(emd1, emdn) + def test_warnings(): n = 100 # nb bins @@ -119,32 +135,22 @@ def test_warnings(): # loss matrix M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2) - # M/=M.max() - - # %% print('Computing {} EMD '.format(1)) with warnings.catch_warnings(record=True) as w: - # Cause all warnings to always be triggered. warnings.simplefilter("always") - # Trigger a warning. print('Computing {} EMD '.format(1)) G = ot.emd(a, b, M, numItermax=1) - # Verify some things assert "numItermax" in str(w[-1].message) assert len(w) == 1 - # Trigger a warning. - a[0]=100 + a[0] = 100 print('Computing {} EMD '.format(2)) G = ot.emd(a, b, M) - # Verify some things assert "infeasible" in str(w[-1].message) assert len(w) == 2 - # Trigger a warning. - a[0]=-1 + a[0] = -1 print('Computing {} EMD '.format(2)) G = ot.emd(a, b, M) - # Verify some things assert "infeasible" in str(w[-1].message) assert len(w) == 3 @@ -167,9 +173,6 @@ def test_dual_variables(): # loss matrix M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2) - # M/=M.max() - - # %% print('Computing {} EMD '.format(1)) @@ -178,26 +181,28 @@ def test_dual_variables(): G, log = ot.emd(a, b, M, log=True) ot.toc('1 proc : {} s') - cost1 = (G * M).sum() - cost_dual = np.vdot(a, log['u']) + np.vdot(b, log['v']) - ot.tic() G2 = ot.emd(b, a, np.ascontiguousarray(M.T)) ot.toc('1 proc : {} s') - cost2 = (G2 * M.T).sum() + cost1 = (G * M).sum() + # Check symmetry + np.testing.assert_array_almost_equal(cost1, (M * G2.T).sum()) + # Check with closed-form solution for gaussians + np.testing.assert_almost_equal(cost1, np.abs(mean1 - mean2)) # Check that both cost computations are equivalent np.testing.assert_almost_equal(cost1, log['cost']) + check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost']) + + +def check_duality_gap(a, b, M, G, u, v, cost): + cost_dual = np.vdot(a, u) + np.vdot(b, v) # Check that dual and primal cost are equal - np.testing.assert_almost_equal(cost1, cost_dual) - # Check symmetry - np.testing.assert_almost_equal(cost1, cost2) - # Check with closed-form solution for gaussians - np.testing.assert_almost_equal(cost1, np.abs(mean1 - mean2)) + np.testing.assert_almost_equal(cost_dual, cost) [ind1, ind2] = np.nonzero(G) # Check that reduced cost is zero on transport arcs - np.testing.assert_array_almost_equal((M - log['u'].reshape(-1, 1) - log['v'].reshape(1, -1))[ind1, ind2], + np.testing.assert_array_almost_equal((M - u.reshape(-1, 1) - v.reshape(1, -1))[ind1, ind2], np.zeros(ind1.size)) -- cgit v1.2.3 From e58cd780ccf87736265e4e1a39afa3a167325ccc Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Sat, 9 Sep 2017 12:37:56 +0900 Subject: Added convergence status to the log --- ot/lp/__init__.py | 16 ++++++++++++---- ot/lp/emd_wrap.pyx | 28 +++++++++++++++++----------- 2 files changed, 29 insertions(+), 15 deletions(-) (limited to 'ot') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 8edd8ec..0f40c19 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -12,7 +12,7 @@ import multiprocessing import numpy as np # import compiled emd -from .emd_wrap import emd_c +from .emd_wrap import emd_c, checkResult from ..utils import parmap @@ -94,12 +94,15 @@ def emd(a, b, M, numItermax=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - G, cost, u, v = emd_c(a, b, M, numItermax) + G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) + resultCodeString = checkResult(resultCode) if log: log = {} log['cost'] = cost log['u'] = u log['v'] = v + log['warning'] = resultCodeString + log['resultCode'] = resultCode return G, log return G @@ -177,15 +180,20 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log= if log: def f(b): - G, cost, u, v = emd_c(a, b, M, numItermax) + G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) + resultCodeString = checkResult(resultCode) log = {} log['G'] = G log['u'] = u log['v'] = v + log['warning'] = resultCodeString + log['resultCode'] = resultCode return [cost, log] else: def f(b): - return emd_c(a, b, M, numItermax)[1] + G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) + checkResult(resultCode) + return cost if len(b.shape) == 1: return f(b) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 5618dfc..19bcdd8 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -7,12 +7,12 @@ Cython linker with C solver # # License: MIT License -import warnings import numpy as np cimport numpy as np cimport cython +import warnings cdef extern from "EMD.h": @@ -20,6 +20,19 @@ cdef extern from "EMD.h": cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED +def checkResult(resultCode): + if resultCode == OPTIMAL: + return None + + if resultCode == INFEASIBLE: + message = "Problem infeasible. Check that a and b are in the simplex" + elif resultCode == UNBOUNDED: + message = "Problem unbounded" + elif resultCode == MAX_ITER_REACHED: + message = "numItermax reached before optimality. Try to increase numItermax." + warnings.warn(message) + return message + @cython.boundscheck(False) @cython.wraparound(False) @@ -77,13 +90,6 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod b=np.ones((n2,))/n2 # calling the function - cdef int resultSolver = EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, numItermax) - if resultSolver != OPTIMAL: - if resultSolver == INFEASIBLE: - warnings.warn("Problem infeasible. Check that a and b are in the simplex") - elif resultSolver == UNBOUNDED: - warnings.warn("Problem unbounded") - elif resultSolver == MAX_ITER_REACHED: - warnings.warn("numItermax reached before optimality. Try to increase numItermax.") - - return G, cost, alpha, beta + cdef int resultCode = EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, numItermax) + + return G, cost, alpha, beta, resultCode -- cgit v1.2.3 From 85c56d96f609c4ad458f0963a068386cc910c66c Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Sat, 9 Sep 2017 17:28:38 +0900 Subject: Renamed variables --- ot/da.py | 2 +- ot/lp/__init__.py | 31 +++++++++++++++++-------------- ot/lp/emd_wrap.pyx | 18 +++++++++--------- test/test_ot.py | 2 +- 4 files changed, 28 insertions(+), 25 deletions(-) (limited to 'ot') diff --git a/ot/da.py b/ot/da.py index 1d3d0ba..eb70305 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1370,7 +1370,7 @@ class EMDTransport(BaseTransport): # coupling estimation self.coupling_ = emd( - a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter + a=self.mu_s, b=self.mu_t, M=self.cost_, num_iter_max=self.max_iter ) return self diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 0f40c19..ab7cb97 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -12,11 +12,11 @@ import multiprocessing import numpy as np # import compiled emd -from .emd_wrap import emd_c, checkResult +from .emd_wrap import emd_c, check_result from ..utils import parmap -def emd(a, b, M, numItermax=100000, log=False): +def emd(a, b, M, num_iter_max=100000, log=False): """Solves the Earth Movers distance problem and returns the OT matrix @@ -41,7 +41,7 @@ def emd(a, b, M, numItermax=100000, log=False): Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix - numItermax : int, optional (default=100000) + num_iter_max : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. log: boolean, optional (default=False) @@ -54,7 +54,7 @@ def emd(a, b, M, numItermax=100000, log=False): Optimal transportation matrix for the given parameters log: dict If input log is true, a dictionary containing the cost and dual - variables + variables and exit status Examples @@ -94,20 +94,20 @@ def emd(a, b, M, numItermax=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) - resultCodeString = checkResult(resultCode) + G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) + resultCodeString = check_result(result_code) if log: log = {} log['cost'] = cost log['u'] = u log['v'] = v log['warning'] = resultCodeString - log['resultCode'] = resultCode + log['result_code'] = result_code return G, log return G -def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False): +def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, log=False): """Solves the Earth Movers distance problem and returns the loss .. math:: @@ -131,7 +131,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log= Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix - numItermax : int, optional (default=100000) + num_iter_max : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. @@ -139,6 +139,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log= ------- gamma: (ns x nt) ndarray Optimal transportation matrix for the given parameters + log: dict + If input log is true, a dictionary containing the cost and dual + variables and exit status Examples @@ -180,19 +183,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log= if log: def f(b): - G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) - resultCodeString = checkResult(resultCode) + G, cost, u, v, resultCode = emd_c(a, b, M, num_iter_max) + resultCodeString = check_result(resultCode) log = {} log['G'] = G log['u'] = u log['v'] = v log['warning'] = resultCodeString - log['resultCode'] = resultCode + log['result_code'] = resultCode return [cost, log] else: def f(b): - G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) - checkResult(resultCode) + G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) + check_result(result_code) return cost if len(b.shape) == 1: diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 19bcdd8..7ebdd2a 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -20,15 +20,15 @@ cdef extern from "EMD.h": cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED -def checkResult(resultCode): - if resultCode == OPTIMAL: +def check_result(result_code): + if result_code == OPTIMAL: return None - if resultCode == INFEASIBLE: + if result_code == INFEASIBLE: message = "Problem infeasible. Check that a and b are in the simplex" - elif resultCode == UNBOUNDED: + elif result_code == UNBOUNDED: message = "Problem unbounded" - elif resultCode == MAX_ITER_REACHED: + elif result_code == MAX_ITER_REACHED: message = "numItermax reached before optimality. Try to increase numItermax." warnings.warn(message) return message @@ -36,7 +36,7 @@ def checkResult(resultCode): @cython.boundscheck(False) @cython.wraparound(False) -def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int numItermax): +def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int num_iter_max): """ Solves the Earth Movers distance problem and returns the optimal transport matrix @@ -63,7 +63,7 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod target histogram M : (ns,nt) ndarray, float64 loss matrix - numItermax : int + num_iter_max : int The maximum number of iterations before stopping the optimization algorithm if it has not converged. @@ -90,6 +90,6 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod b=np.ones((n2,))/n2 # calling the function - cdef int resultCode = EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, numItermax) + cdef int result_code = EMD_wrap(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, num_iter_max) - return G, cost, alpha, beta, resultCode + return G, cost, alpha, beta, result_code diff --git a/test/test_ot.py b/test/test_ot.py index cf5839e..c9b5154 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -140,7 +140,7 @@ def test_warnings(): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") print('Computing {} EMD '.format(1)) - ot.emd(a, b, M, numItermax=1) + ot.emd(a, b, M, num_iter_max=1) assert "numItermax" in str(w[-1].message) assert len(w) == 1 a[0] = 100 -- cgit v1.2.3 From 1ba2c837d54ce963ad63ddf8df2e47230800b747 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Sat, 9 Sep 2017 17:30:23 +0900 Subject: Renamed variables --- ot/lp/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'ot') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index ab7cb97..1238cdb 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -95,13 +95,13 @@ def emd(a, b, M, num_iter_max=100000, log=False): b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) - resultCodeString = check_result(result_code) + result_code_string = check_result(result_code) if log: log = {} log['cost'] = cost log['u'] = u log['v'] = v - log['warning'] = resultCodeString + log['warning'] = result_code_string log['result_code'] = result_code return G, log return G @@ -184,12 +184,12 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo if log: def f(b): G, cost, u, v, resultCode = emd_c(a, b, M, num_iter_max) - resultCodeString = check_result(resultCode) + result_code_string = check_result(resultCode) log = {} log['G'] = G log['u'] = u log['v'] = v - log['warning'] = resultCodeString + log['warning'] = result_code_string log['result_code'] = resultCode return [cost, log] else: -- cgit v1.2.3 From cd8c04246b6d1f15b68d6433741e8c808fd517d8 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Sat, 9 Sep 2017 17:38:31 +0900 Subject: Renamed variable --- ot/da.py | 2 +- ot/lp/EMD.h | 2 +- ot/lp/EMD_wrapper.cpp | 4 ++-- ot/lp/__init__.py | 14 +++++++------- ot/lp/emd_wrap.pyx | 8 ++++---- test/test_ot.py | 2 +- 6 files changed, 16 insertions(+), 16 deletions(-) (limited to 'ot') diff --git a/ot/da.py b/ot/da.py index eb70305..f3e7433 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1370,7 +1370,7 @@ class EMDTransport(BaseTransport): # coupling estimation self.coupling_ = emd( - a=self.mu_s, b=self.mu_t, M=self.cost_, num_iter_max=self.max_iter + a=self.mu_s, b=self.mu_t, M=self.cost_, max_iter=self.max_iter ) return self diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index bb486de..f42e222 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -30,6 +30,6 @@ enum ProblemType { MAX_ITER_REACHED }; -int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int max_iter); +int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter); #endif diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 92663dc..fc7ca63 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -16,7 +16,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, - double* alpha, double* beta, double *cost, int max_iter) { + double* alpha, double* beta, double *cost, int maxIter) { // beware M and C anre strored in row major C style!!! int n, m, i, cur; @@ -48,7 +48,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, std::vector indI(n), indJ(m); std::vector weights1(n), weights2(m); Digraph di(n, m); - NetworkSimplexSimple net(di, true, n+m, n*m, max_iter); + NetworkSimplexSimple net(di, true, n+m, n*m, maxIter); // Set supply and demand, don't account for 0 values (faster) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 1238cdb..9a0cb1c 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -16,7 +16,7 @@ from .emd_wrap import emd_c, check_result from ..utils import parmap -def emd(a, b, M, num_iter_max=100000, log=False): +def emd(a, b, M, max_iter=100000, log=False): """Solves the Earth Movers distance problem and returns the OT matrix @@ -41,7 +41,7 @@ def emd(a, b, M, num_iter_max=100000, log=False): Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix - num_iter_max : int, optional (default=100000) + max_iter : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. log: boolean, optional (default=False) @@ -94,7 +94,7 @@ def emd(a, b, M, num_iter_max=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) + G, cost, u, v, result_code = emd_c(a, b, M, max_iter) result_code_string = check_result(result_code) if log: log = {} @@ -107,7 +107,7 @@ def emd(a, b, M, num_iter_max=100000, log=False): return G -def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, log=False): +def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=False): """Solves the Earth Movers distance problem and returns the loss .. math:: @@ -131,7 +131,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix - num_iter_max : int, optional (default=100000) + max_iter : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. @@ -183,7 +183,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo if log: def f(b): - G, cost, u, v, resultCode = emd_c(a, b, M, num_iter_max) + G, cost, u, v, resultCode = emd_c(a, b, M, max_iter) result_code_string = check_result(resultCode) log = {} log['G'] = G @@ -194,7 +194,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo return [cost, log] else: def f(b): - G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) + G, cost, u, v, result_code = emd_c(a, b, M, max_iter) check_result(result_code) return cost diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 7ebdd2a..83ee6aa 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -16,7 +16,7 @@ import warnings cdef extern from "EMD.h": - int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int numItermax) + int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -36,7 +36,7 @@ def check_result(result_code): @cython.boundscheck(False) @cython.wraparound(False) -def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int num_iter_max): +def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter): """ Solves the Earth Movers distance problem and returns the optimal transport matrix @@ -63,7 +63,7 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod target histogram M : (ns,nt) ndarray, float64 loss matrix - num_iter_max : int + max_iter : int The maximum number of iterations before stopping the optimization algorithm if it has not converged. @@ -90,6 +90,6 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod b=np.ones((n2,))/n2 # calling the function - cdef int result_code = EMD_wrap(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, num_iter_max) + cdef int result_code = EMD_wrap(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter) return G, cost, alpha, beta, result_code diff --git a/test/test_ot.py b/test/test_ot.py index c9b5154..ca921c5 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -140,7 +140,7 @@ def test_warnings(): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") print('Computing {} EMD '.format(1)) - ot.emd(a, b, M, num_iter_max=1) + ot.emd(a, b, M, max_iter=1) assert "numItermax" in str(w[-1].message) assert len(w) == 1 a[0] = 100 -- cgit v1.2.3 From 8cc04ef5ae8806c81811b2081b1880b46ca063a3 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Sat, 9 Sep 2017 18:05:12 +0900 Subject: Renamed variable in string --- ot/lp/__init__.py | 1 - ot/lp/emd_wrap.pyx | 2 +- test/test_ot.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) (limited to 'ot') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 9a0cb1c..ae5b08c 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -201,7 +201,6 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=Fa if len(b.shape) == 1: return f(b) nb = b.shape[1] - # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)] res = parmap(f, [b[:, i] for i in range(nb)], processes) return res diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 83ee6aa..2fcc0e4 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -29,7 +29,7 @@ def check_result(result_code): elif result_code == UNBOUNDED: message = "Problem unbounded" elif result_code == MAX_ITER_REACHED: - message = "numItermax reached before optimality. Try to increase numItermax." + message = "max_iter reached before optimality. Try to increase max_iter." warnings.warn(message) return message diff --git a/test/test_ot.py b/test/test_ot.py index ca921c5..46fc634 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -141,7 +141,7 @@ def test_warnings(): warnings.simplefilter("always") print('Computing {} EMD '.format(1)) ot.emd(a, b, M, max_iter=1) - assert "numItermax" in str(w[-1].message) + assert "max_iter" in str(w[-1].message) assert len(w) == 1 a[0] = 100 print('Computing {} EMD '.format(2)) -- cgit v1.2.3 From 06429e5a34790ec51eb1c921293b24c37b81b952 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Sat, 9 Sep 2017 18:23:05 +0900 Subject: Returned to old variable name to follow repo convention --- ot/da.py | 2 +- ot/lp/__init__.py | 12 ++++++------ ot/lp/emd_wrap.pyx | 2 +- test/test_ot.py | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) (limited to 'ot') diff --git a/ot/da.py b/ot/da.py index f3e7433..eb70305 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1370,7 +1370,7 @@ class EMDTransport(BaseTransport): # coupling estimation self.coupling_ = emd( - a=self.mu_s, b=self.mu_t, M=self.cost_, max_iter=self.max_iter + a=self.mu_s, b=self.mu_t, M=self.cost_, num_iter_max=self.max_iter ) return self diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index ae5b08c..17f5bb4 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -16,7 +16,7 @@ from .emd_wrap import emd_c, check_result from ..utils import parmap -def emd(a, b, M, max_iter=100000, log=False): +def emd(a, b, M, num_iter_max=100000, log=False): """Solves the Earth Movers distance problem and returns the OT matrix @@ -41,7 +41,7 @@ def emd(a, b, M, max_iter=100000, log=False): Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix - max_iter : int, optional (default=100000) + num_iter_max : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. log: boolean, optional (default=False) @@ -94,7 +94,7 @@ def emd(a, b, M, max_iter=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - G, cost, u, v, result_code = emd_c(a, b, M, max_iter) + G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) result_code_string = check_result(result_code) if log: log = {} @@ -107,7 +107,7 @@ def emd(a, b, M, max_iter=100000, log=False): return G -def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=False): +def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, log=False): """Solves the Earth Movers distance problem and returns the loss .. math:: @@ -183,7 +183,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=Fa if log: def f(b): - G, cost, u, v, resultCode = emd_c(a, b, M, max_iter) + G, cost, u, v, resultCode = emd_c(a, b, M, num_iter_max) result_code_string = check_result(resultCode) log = {} log['G'] = G @@ -194,7 +194,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=Fa return [cost, log] else: def f(b): - G, cost, u, v, result_code = emd_c(a, b, M, max_iter) + G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) check_result(result_code) return cost diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 2fcc0e4..45fc988 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -29,7 +29,7 @@ def check_result(result_code): elif result_code == UNBOUNDED: message = "Problem unbounded" elif result_code == MAX_ITER_REACHED: - message = "max_iter reached before optimality. Try to increase max_iter." + message = "num_iter_max reached before optimality. Try to increase num_iter_max." warnings.warn(message) return message diff --git a/test/test_ot.py b/test/test_ot.py index 46fc634..e05e8aa 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -140,8 +140,8 @@ def test_warnings(): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") print('Computing {} EMD '.format(1)) - ot.emd(a, b, M, max_iter=1) - assert "max_iter" in str(w[-1].message) + ot.emd(a, b, M, num_iter_max=1) + assert "num_iter_max" in str(w[-1].message) assert len(w) == 1 a[0] = 100 print('Computing {} EMD '.format(2)) -- cgit v1.2.3 From 7c6169222979a7e82a83c118bc7117684258d0de Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Sat, 9 Sep 2017 18:29:32 +0900 Subject: Updated variable name in docstring --- ot/lp/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'ot') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 17f5bb4..f2eaa2b 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -131,7 +131,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix - max_iter : int, optional (default=100000) + num_iter_max : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. -- cgit v1.2.3 From dd6f8260d01ce173ef3fe0c900112f0ed5288950 Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Tue, 12 Sep 2017 19:58:46 +0900 Subject: Made the return of the matrix optional in emd2 --- ot/lp/__init__.py | 12 +++++++++--- test/test_ot.py | 6 +++--- 2 files changed, 12 insertions(+), 6 deletions(-) (limited to 'ot') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index f2eaa2b..d0f682b 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -107,7 +107,7 @@ def emd(a, b, M, num_iter_max=100000, log=False): return G -def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, log=False): +def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, log=False, return_matrix=False): """Solves the Earth Movers distance problem and returns the loss .. math:: @@ -134,6 +134,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo num_iter_max : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost and dual + variables. Otherwise returns only the optimal transportation cost. + return_matrix: boolean, optional (default=False) + If True, returns the optimal transportation matrix in the log. Returns ------- @@ -181,12 +186,13 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - if log: + if log or return_matrix: def f(b): G, cost, u, v, resultCode = emd_c(a, b, M, num_iter_max) result_code_string = check_result(resultCode) log = {} - log['G'] = G + if return_matrix: + log['G'] = G log['u'] = u log['v'] = v log['warning'] = result_code_string diff --git a/test/test_ot.py b/test/test_ot.py index e05e8aa..ea6d9dc 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -103,7 +103,7 @@ def test_emd2_multi(): # emd loss multipro proc with log ot.tic() - emdn = ot.emd2(a, b, M, log=True) + emdn = ot.emd2(a, b, M, log=True, return_matrix=True) ot.toc('multi proc : {} s') for i in range(len(emdn)): @@ -140,8 +140,8 @@ def test_warnings(): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") print('Computing {} EMD '.format(1)) - ot.emd(a, b, M, num_iter_max=1) - assert "num_iter_max" in str(w[-1].message) + ot.emd(a, b, M, numItermax=1) + assert "numItermax" in str(w[-1].message) assert len(w) == 1 a[0] = 100 print('Computing {} EMD '.format(2)) -- cgit v1.2.3 From e52b6eb41228a7f8e381cf73c06e0dffba5773be Mon Sep 17 00:00:00 2001 From: Antoine Rolet Date: Tue, 12 Sep 2017 20:00:14 +0900 Subject: Renaming --- ot/da.py | 2 +- ot/lp/__init__.py | 14 +++++++------- ot/lp/emd_wrap.pyx | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) (limited to 'ot') diff --git a/ot/da.py b/ot/da.py index eb70305..1d3d0ba 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1370,7 +1370,7 @@ class EMDTransport(BaseTransport): # coupling estimation self.coupling_ = emd( - a=self.mu_s, b=self.mu_t, M=self.cost_, num_iter_max=self.max_iter + a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter ) return self diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index d0f682b..5c09da2 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -16,7 +16,7 @@ from .emd_wrap import emd_c, check_result from ..utils import parmap -def emd(a, b, M, num_iter_max=100000, log=False): +def emd(a, b, M, numItermax=100000, log=False): """Solves the Earth Movers distance problem and returns the OT matrix @@ -41,7 +41,7 @@ def emd(a, b, M, num_iter_max=100000, log=False): Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix - num_iter_max : int, optional (default=100000) + numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. log: boolean, optional (default=False) @@ -94,7 +94,7 @@ def emd(a, b, M, num_iter_max=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) + G, cost, u, v, result_code = emd_c(a, b, M, numItermax) result_code_string = check_result(result_code) if log: log = {} @@ -107,7 +107,7 @@ def emd(a, b, M, num_iter_max=100000, log=False): return G -def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, log=False, return_matrix=False): +def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False, return_matrix=False): """Solves the Earth Movers distance problem and returns the loss .. math:: @@ -131,7 +131,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 loss matrix - num_iter_max : int, optional (default=100000) + numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. log: boolean, optional (default=False) @@ -188,7 +188,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo if log or return_matrix: def f(b): - G, cost, u, v, resultCode = emd_c(a, b, M, num_iter_max) + G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) result_code_string = check_result(resultCode) log = {} if return_matrix: @@ -200,7 +200,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo return [cost, log] else: def f(b): - G, cost, u, v, result_code = emd_c(a, b, M, num_iter_max) + G, cost, u, v, result_code = emd_c(a, b, M, numItermax) check_result(result_code) return cost diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 45fc988..83ee6aa 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -29,7 +29,7 @@ def check_result(result_code): elif result_code == UNBOUNDED: message = "Problem unbounded" elif result_code == MAX_ITER_REACHED: - message = "num_iter_max reached before optimality. Try to increase num_iter_max." + message = "numItermax reached before optimality. Try to increase numItermax." warnings.warn(message) return message -- cgit v1.2.3