diff options
Diffstat (limited to 'ot/lp/EMD_wrapper.cpp')
-rw-r--r-- | ot/lp/EMD_wrapper.cpp | 124 |
1 files changed, 117 insertions, 7 deletions
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index bc873ed..2bdc172 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -12,16 +12,22 @@ * */ + +#include "network_simplex_simple.h" +#include "network_simplex_simple_omp.h" #include "EMD.h" +#include <cstdint> int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) { - // beware M and C anre strored in row major C style!!! - int n, m, i, cur; + // beware M and C are stored in row major C style!!! + + using namespace lemon; + int n, m, cur; typedef FullBipartiteDigraph Digraph; - DIGRAPH_TYPEDEFS(FullBipartiteDigraph); + DIGRAPH_TYPEDEFS(Digraph); // Get the number of non zero coordinates for r and c n=0; @@ -48,7 +54,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, std::vector<int> indI(n), indJ(m); std::vector<double> weights1(n), weights2(m); Digraph di(n, m); - NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter); + NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter); // Set supply and demand, don't account for 0 values (faster) @@ -76,10 +82,12 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, net.supplyMap(&weights1[0], n, &weights2[0], m); // Set the cost of each edge + int64_t idarc = 0; for (int i=0; i<n; i++) { for (int j=0; j<m; j++) { double val=*(D+indI[i]*n2+indJ[j]); - net.setCost(di.arcFromId(i*m+j), val); + net.setCost(di.arcFromId(idarc), val); + ++idarc; } } @@ -87,12 +95,13 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, // Solve the problem with the network simplex algorithm int ret=net.run(); + int i, j; if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) { *cost = 0; Arc a; di.first(a); for (; a != INVALID; di.next(a)) { - int i = di.source(a); - int j = di.target(a); + i = di.source(a); + j = di.target(a); double flow = net.flow(a); *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); *(G+indI[i]*n2+indJ[j-n]) = flow; @@ -106,3 +115,104 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, return ret; } + + + + + + +int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G, + double* alpha, double* beta, double *cost, int maxIter, int numThreads) { + // beware M and C are stored in row major C style!!! + + using namespace lemon_omp; + int n, m, cur; + + typedef FullBipartiteDigraph Digraph; + DIGRAPH_TYPEDEFS(Digraph); + + // Get the number of non zero coordinates for r and c + n=0; + for (int i=0; i<n1; i++) { + double val=*(X+i); + if (val>0) { + n++; + }else if(val<0){ + return INFEASIBLE; + } + } + m=0; + for (int i=0; i<n2; i++) { + double val=*(Y+i); + if (val>0) { + m++; + }else if(val<0){ + return INFEASIBLE; + } + } + + // Define the graph + + std::vector<int> indI(n), indJ(m); + std::vector<double> weights1(n), weights2(m); + Digraph di(n, m); + NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads); + + // Set supply and demand, don't account for 0 values (faster) + + cur=0; + for (int i=0; i<n1; i++) { + double val=*(X+i); + if (val>0) { + weights1[ cur ] = val; + indI[cur++]=i; + } + } + + // Demand is actually negative supply... + + cur=0; + for (int i=0; i<n2; i++) { + double val=*(Y+i); + if (val>0) { + weights2[ cur ] = -val; + indJ[cur++]=i; + } + } + + + net.supplyMap(&weights1[0], n, &weights2[0], m); + + // Set the cost of each edge + int64_t idarc = 0; + for (int i=0; i<n; i++) { + for (int j=0; j<m; j++) { + double val=*(D+indI[i]*n2+indJ[j]); + net.setCost(di.arcFromId(idarc), val); + ++idarc; + } + } + + + // Solve the problem with the network simplex algorithm + + int ret=net.run(); + int i, j; + if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) { + *cost = 0; + Arc a; di.first(a); + for (; a != INVALID; di.next(a)) { + i = di.source(a); + j = di.target(a); + double flow = net.flow(a); + *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); + *(G+indI[i]*n2+indJ[j-n]) = flow; + *(alpha + indI[i]) = -net.potential(i); + *(beta + indJ[j-n]) = net.potential(j); + } + + } + + + return ret; +} |