summaryrefslogtreecommitdiff
path: root/ot/lp/EMD_wrapper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/EMD_wrapper.cpp')
-rw-r--r--ot/lp/EMD_wrapper.cpp65
1 files changed, 46 insertions, 19 deletions
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index 3ca7319..2aa44c1 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -111,23 +111,19 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
long *iG, long *jG, double *G,
double* alpha, double* beta, double *cost, int maxIter) {
// beware M and C anre strored in row major C style!!!
- int n, m, i, cur;
+
+ // Get the number of non zero coordinates for r and c and vectors
+ int n, m, i, cur;
typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
- std::vector<int> indI(n), indJ(m);
- std::vector<double> weights1(n), weights2(m);
- Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
-
- // Get the number of non zero coordinates for r and c and vectors
+ // Get the number of non zero coordinates for r and c
n=0;
for (int i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
- weights1[ n ] = val;
- indI[n++]=i;
+ n++;
}else if(val<0){
return INFEASIBLE;
}
@@ -136,14 +132,42 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
for (int i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
- weights2[ m ] = -val;
- indJ[m++]=i;
+ m++;
}else if(val<0){
return INFEASIBLE;
}
}
// Define the graph
+
+ std::vector<int> indI(n), indJ(m);
+ std::vector<double> weights1(n), weights2(m);
+ Digraph di(n, m);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
+
+ // Set supply and demand, don't account for 0 values (faster)
+
+ cur=0;
+ for (int i=0; i<n1; i++) {
+ double val=*(X+i);
+ if (val>0) {
+ weights1[ cur ] = val;
+ indI[cur++]=i;
+ }
+ }
+
+ // Demand is actually negative supply...
+
+ cur=0;
+ for (int i=0; i<n2; i++) {
+ double val=*(Y+i);
+ if (val>0) {
+ weights2[ cur ] = -val;
+ indJ[cur++]=i;
+ }
+ }
+
+ // Define the graph
net.supplyMap(&weights1[0], n, &weights2[0], m);
// Set the cost of each edge
@@ -166,14 +190,17 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
int i = di.source(a);
int j = di.target(a);
double flow = net.flow(a);
- *cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
-
- *(G+cur) = flow;
- *(iG+cur) = indI[i];
- *(jG+cur) = indJ[j];
- *(alpha + indI[i]) = -net.potential(i);
- *(beta + indJ[j-n]) = net.potential(j);
- cur++;
+ if (flow>0)
+ {
+ *cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
+
+ *(G+cur) = flow;
+ *(iG+cur) = indI[i];
+ *(jG+cur) = indJ[j-n];
+ *(alpha + indI[i]) = -net.potential(i);
+ *(beta + indJ[j-n]) = net.potential(j);
+ cur++;
+ }
}
}