summaryrefslogtreecommitdiff
path: root/ot/lp/EMD_wrapper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/EMD_wrapper.cpp')
-rw-r--r--ot/lp/EMD_wrapper.cpp74
1 files changed, 31 insertions, 43 deletions
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index c8c2eb3..fc7ca63 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -15,62 +15,60 @@
#include "EMD.h"
-int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) {
+int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
+ double* alpha, double* beta, double *cost, int maxIter) {
// beware M and C anre strored in row major C style!!!
- int n, m, i,cur;
- double max;
+ int n, m, i, cur;
typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
// Get the number of non zero coordinates for r and c
n=0;
- for (node_id_type i=0; i<n1; i++) {
+ for (int i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
n++;
- }
+ }else if(val<0){
+ return INFEASIBLE;
+ }
}
m=0;
- for (node_id_type i=0; i<n2; i++) {
+ for (int i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
m++;
- }
+ }else if(val<0){
+ return INFEASIBLE;
+ }
}
-
// Define the graph
std::vector<int> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, max_iter);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
// Set supply and demand, don't account for 0 values (faster)
- max=0;
cur=0;
- for (node_id_type i=0; i<n1; i++) {
+ for (int i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
- weights1[ di.nodeFromId(cur) ] = val;
- max+=val;
+ weights1[ cur ] = val;
indI[cur++]=i;
}
}
// Demand is actually negative supply...
- max=0;
cur=0;
- for (node_id_type i=0; i<n2; i++) {
+ for (int i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
- weights2[ di.nodeFromId(cur) ] = -val;
+ weights2[ cur ] = -val;
indJ[cur++]=i;
-
- max-=val;
}
}
@@ -78,14 +76,10 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c
net.supplyMap(&weights1[0], n, &weights2[0], m);
// Set the cost of each edge
- max=0;
- for (node_id_type i=0; i<n; i++) {
- for (node_id_type j=0; j<m; j++) {
+ for (int i=0; i<n; i++) {
+ for (int j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
net.setCost(di.arcFromId(i*m+j), val);
- if (val>max) {
- max=val;
- }
}
}
@@ -93,26 +87,20 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c
// Solve the problem with the network simplex algorithm
int ret=net.run();
- if (ret!=(int)net.OPTIMAL) {
- if (ret==(int)net.INFEASIBLE) {
- std::cout << "Infeasible problem";
+ if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
+ *cost = 0;
+ Arc a; di.first(a);
+ for (; a != INVALID; di.next(a)) {
+ int i = di.source(a);
+ int j = di.target(a);
+ double flow = net.flow(a);
+ *cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
+ *(G+indI[i]*n2+indJ[j-n]) = flow;
+ *(alpha + indI[i]) = -net.potential(i);
+ *(beta + indJ[j-n]) = net.potential(j);
}
- if (ret==(int)net.UNBOUNDED)
- {
- std::cout << "Unbounded problem";
- }
- } else
- {
- for (node_id_type i=0; i<n; i++)
- {
- for (node_id_type j=0; j<m; j++)
- {
- *(G+indI[i]*n2+indJ[j]) = net.flow(di.arcFromId(i*m+j));
- }
- };
- *cost = net.totalCost();
-
- };
+
+ }
return ret;