summaryrefslogtreecommitdiff
path: root/ot/lp/EMD_wrapper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/EMD_wrapper.cpp')
-rw-r--r--ot/lp/EMD_wrapper.cpp124
1 files changed, 117 insertions, 7 deletions
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index bc873ed..2bdc172 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -12,16 +12,22 @@
*
*/
+
+#include "network_simplex_simple.h"
+#include "network_simplex_simple_omp.h"
#include "EMD.h"
+#include <cstdint>
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
double* alpha, double* beta, double *cost, int maxIter) {
- // beware M and C anre strored in row major C style!!!
- int n, m, i, cur;
+ // beware M and C are stored in row major C style!!!
+
+ using namespace lemon;
+ int n, m, cur;
typedef FullBipartiteDigraph Digraph;
- DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
+ DIGRAPH_TYPEDEFS(Digraph);
// Get the number of non zero coordinates for r and c
n=0;
@@ -48,7 +54,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
std::vector<int> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter);
// Set supply and demand, don't account for 0 values (faster)
@@ -76,10 +82,12 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
net.supplyMap(&weights1[0], n, &weights2[0], m);
// Set the cost of each edge
+ int64_t idarc = 0;
for (int i=0; i<n; i++) {
for (int j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
- net.setCost(di.arcFromId(i*m+j), val);
+ net.setCost(di.arcFromId(idarc), val);
+ ++idarc;
}
}
@@ -87,12 +95,13 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
// Solve the problem with the network simplex algorithm
int ret=net.run();
+ int i, j;
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
*cost = 0;
Arc a; di.first(a);
for (; a != INVALID; di.next(a)) {
- int i = di.source(a);
- int j = di.target(a);
+ i = di.source(a);
+ j = di.target(a);
double flow = net.flow(a);
*cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
*(G+indI[i]*n2+indJ[j-n]) = flow;
@@ -106,3 +115,104 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
return ret;
}
+
+
+
+
+
+
+int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
+ double* alpha, double* beta, double *cost, int maxIter, int numThreads) {
+ // beware M and C are stored in row major C style!!!
+
+ using namespace lemon_omp;
+ int n, m, cur;
+
+ typedef FullBipartiteDigraph Digraph;
+ DIGRAPH_TYPEDEFS(Digraph);
+
+ // Get the number of non zero coordinates for r and c
+ n=0;
+ for (int i=0; i<n1; i++) {
+ double val=*(X+i);
+ if (val>0) {
+ n++;
+ }else if(val<0){
+ return INFEASIBLE;
+ }
+ }
+ m=0;
+ for (int i=0; i<n2; i++) {
+ double val=*(Y+i);
+ if (val>0) {
+ m++;
+ }else if(val<0){
+ return INFEASIBLE;
+ }
+ }
+
+ // Define the graph
+
+ std::vector<int> indI(n), indJ(m);
+ std::vector<double> weights1(n), weights2(m);
+ Digraph di(n, m);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads);
+
+ // Set supply and demand, don't account for 0 values (faster)
+
+ cur=0;
+ for (int i=0; i<n1; i++) {
+ double val=*(X+i);
+ if (val>0) {
+ weights1[ cur ] = val;
+ indI[cur++]=i;
+ }
+ }
+
+ // Demand is actually negative supply...
+
+ cur=0;
+ for (int i=0; i<n2; i++) {
+ double val=*(Y+i);
+ if (val>0) {
+ weights2[ cur ] = -val;
+ indJ[cur++]=i;
+ }
+ }
+
+
+ net.supplyMap(&weights1[0], n, &weights2[0], m);
+
+ // Set the cost of each edge
+ int64_t idarc = 0;
+ for (int i=0; i<n; i++) {
+ for (int j=0; j<m; j++) {
+ double val=*(D+indI[i]*n2+indJ[j]);
+ net.setCost(di.arcFromId(idarc), val);
+ ++idarc;
+ }
+ }
+
+
+ // Solve the problem with the network simplex algorithm
+
+ int ret=net.run();
+ int i, j;
+ if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
+ *cost = 0;
+ Arc a; di.first(a);
+ for (; a != INVALID; di.next(a)) {
+ i = di.source(a);
+ j = di.target(a);
+ double flow = net.flow(a);
+ *cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
+ *(G+indI[i]*n2+indJ[j-n]) = flow;
+ *(alpha + indI[i]) = -net.potential(i);
+ *(beta + indJ[j-n]) = net.potential(j);
+ }
+
+ }
+
+
+ return ret;
+}