diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2019-11-29 09:38:29 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2019-11-29 09:38:29 +0100 |
commit | e92ae6d155a6bed91c474a3e842581f09deceba3 (patch) | |
tree | 8d93d7c43325aa9a48881874bf44900db139ff4c /ot/lp | |
parent | 7a02c69a3791682cc3993f7a20ed6841eef75441 (diff) |
cleanup cpp code and annd emd with sparse Ot matrix
Diffstat (limited to 'ot/lp')
-rw-r--r-- | ot/lp/EMD_wrapper.cpp | 95 |
1 files changed, 75 insertions, 20 deletions
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index fc7ca63..91110b4 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -17,18 +17,24 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) { -// beware M and C anre strored in row major C style!!! + // beware M and C anre strored in row major C style!!! int n, m, i, cur; typedef FullBipartiteDigraph Digraph; - DIGRAPH_TYPEDEFS(FullBipartiteDigraph); + DIGRAPH_TYPEDEFS(FullBipartiteDigraph); - // Get the number of non zero coordinates for r and c + std::vector<int> indI(n), indJ(m); + std::vector<double> weights1(n), weights2(m); + Digraph di(n, m); + NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter); + + // Get the number of non zero coordinates for r and c and vectors n=0; for (int i=0; i<n1; i++) { double val=*(X+i); if (val>0) { - n++; + weights1[ n ] = val; + indI[n++]=i; }else if(val<0){ return INFEASIBLE; } @@ -37,42 +43,85 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, for (int i=0; i<n2; i++) { double val=*(Y+i); if (val>0) { - m++; + weights2[ m ] = -val; + indJ[m++]=i; }else if(val<0){ return INFEASIBLE; } } // Define the graph + net.supplyMap(&weights1[0], n, &weights2[0], m); + + // Set the cost of each edge + for (int i=0; i<n; i++) { + for (int j=0; j<m; j++) { + double val=*(D+indI[i]*n2+indJ[j]); + net.setCost(di.arcFromId(i*m+j), val); + } + } + + + // Solve the problem with the network simplex algorithm + + int ret=net.run(); + if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) { + *cost = 0; + Arc a; di.first(a); + for (; a != INVALID; di.next(a)) { + int i = di.source(a); + int j = di.target(a); + double flow = net.flow(a); + *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); + *(G+indI[i]*n2+indJ[j-n]) = flow; + *(alpha + indI[i]) = -net.potential(i); + *(beta + indJ[j-n]) = net.potential(j); + } + + } + + + return ret; +} + + +int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, + int *iG, int *jG, double *G, + double* alpha, double* beta, double *cost, int maxIter) { + // beware M and C anre strored in row major C style!!! + int n, m, i, cur; + + typedef FullBipartiteDigraph Digraph; + DIGRAPH_TYPEDEFS(FullBipartiteDigraph); std::vector<int> indI(n), indJ(m); std::vector<double> weights1(n), weights2(m); Digraph di(n, m); NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter); - // Set supply and demand, don't account for 0 values (faster) - - cur=0; + // Get the number of non zero coordinates for r and c and vectors + n=0; for (int i=0; i<n1; i++) { double val=*(X+i); if (val>0) { - weights1[ cur ] = val; - indI[cur++]=i; - } + weights1[ n ] = val; + indI[n++]=i; + }else if(val<0){ + return INFEASIBLE; + } } - - // Demand is actually negative supply... - - cur=0; + m=0; for (int i=0; i<n2; i++) { double val=*(Y+i); if (val>0) { - weights2[ cur ] = -val; - indJ[cur++]=i; - } + weights2[ m ] = -val; + indJ[m++]=i; + }else if(val<0){ + return INFEASIBLE; + } } - + // Define the graph net.supplyMap(&weights1[0], n, &weights2[0], m); // Set the cost of each edge @@ -90,14 +139,19 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) { *cost = 0; Arc a; di.first(a); + cur=0 for (; a != INVALID; di.next(a)) { int i = di.source(a); int j = di.target(a); double flow = net.flow(a); *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); - *(G+indI[i]*n2+indJ[j-n]) = flow; + + *(G+cur) = flow; + *(iG+cur) = i; + *(jG+cur) = j; *(alpha + indI[i]) = -net.potential(i); *(beta + indJ[j-n]) = net.potential(j); + cur++; } } @@ -105,3 +159,4 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, return ret; } + |