diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2017-09-13 08:11:00 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-09-13 08:11:00 +0200 |
commit | a53ede95f916a11e2150ab7917820d813c0034bc (patch) | |
tree | 24304d83267d51b962e18722553973bbc75509f2 /ot/lp/EMD_wrapper.cpp | |
parent | 62dcfbfb78a2be24379cd5cdb4aec70d8c4befaa (diff) | |
parent | e52b6eb41228a7f8e381cf73c06e0dffba5773be (diff) |
Merge pull request #29 from arolet/ot_dual_variables
Dual variables in EMD_wrapper
Diffstat (limited to 'ot/lp/EMD_wrapper.cpp')
-rw-r--r-- | ot/lp/EMD_wrapper.cpp | 74 |
1 files changed, 31 insertions, 43 deletions
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index c8c2eb3..fc7ca63 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -15,62 +15,60 @@ #include "EMD.h" -int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) { +int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, + double* alpha, double* beta, double *cost, int maxIter) { // beware M and C anre strored in row major C style!!! - int n, m, i,cur; - double max; + int n, m, i, cur; typedef FullBipartiteDigraph Digraph; DIGRAPH_TYPEDEFS(FullBipartiteDigraph); // Get the number of non zero coordinates for r and c n=0; - for (node_id_type i=0; i<n1; i++) { + for (int i=0; i<n1; i++) { double val=*(X+i); if (val>0) { n++; - } + }else if(val<0){ + return INFEASIBLE; + } } m=0; - for (node_id_type i=0; i<n2; i++) { + for (int i=0; i<n2; i++) { double val=*(Y+i); if (val>0) { m++; - } + }else if(val<0){ + return INFEASIBLE; + } } - // Define the graph std::vector<int> indI(n), indJ(m); std::vector<double> weights1(n), weights2(m); Digraph di(n, m); - NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, max_iter); + NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter); // Set supply and demand, don't account for 0 values (faster) - max=0; cur=0; - for (node_id_type i=0; i<n1; i++) { + for (int i=0; i<n1; i++) { double val=*(X+i); if (val>0) { - weights1[ di.nodeFromId(cur) ] = val; - max+=val; + weights1[ cur ] = val; indI[cur++]=i; } } // Demand is actually negative supply... - max=0; cur=0; - for (node_id_type i=0; i<n2; i++) { + for (int i=0; i<n2; i++) { double val=*(Y+i); if (val>0) { - weights2[ di.nodeFromId(cur) ] = -val; + weights2[ cur ] = -val; indJ[cur++]=i; - - max-=val; } } @@ -78,14 +76,10 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c net.supplyMap(&weights1[0], n, &weights2[0], m); // Set the cost of each edge - max=0; - for (node_id_type i=0; i<n; i++) { - for (node_id_type j=0; j<m; j++) { + for (int i=0; i<n; i++) { + for (int j=0; j<m; j++) { double val=*(D+indI[i]*n2+indJ[j]); net.setCost(di.arcFromId(i*m+j), val); - if (val>max) { - max=val; - } } } @@ -93,26 +87,20 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c // Solve the problem with the network simplex algorithm int ret=net.run(); - if (ret!=(int)net.OPTIMAL) { - if (ret==(int)net.INFEASIBLE) { - std::cout << "Infeasible problem"; + if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) { + *cost = 0; + Arc a; di.first(a); + for (; a != INVALID; di.next(a)) { + int i = di.source(a); + int j = di.target(a); + double flow = net.flow(a); + *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); + *(G+indI[i]*n2+indJ[j-n]) = flow; + *(alpha + indI[i]) = -net.potential(i); + *(beta + indJ[j-n]) = net.potential(j); } - if (ret==(int)net.UNBOUNDED) - { - std::cout << "Unbounded problem"; - } - } else - { - for (node_id_type i=0; i<n; i++) - { - for (node_id_type j=0; j<m; j++) - { - *(G+indI[i]*n2+indJ[j]) = net.flow(di.arcFromId(i*m+j)); - } - }; - *cost = net.totalCost(); - - }; + + } return ret; |