diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2019-11-29 09:46:35 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2019-11-29 09:46:35 +0100 |
commit | 3a858dfa1f2795b22d1e2db3cfd94d1eb7831f8d (patch) | |
tree | 0d64f05129f62a0d34a743af2e5b92c4737ba9c8 /ot | |
parent | df0d259ebab268517716d666ae45494b6ba998ea (diff) |
correct bad speedup
Diffstat (limited to 'ot')
-rw-r--r-- | ot/lp/EMD_wrapper.cpp | 48 |
1 files changed, 35 insertions, 13 deletions
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 29e2303..65fa80f 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -18,23 +18,17 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) { // beware M and C anre strored in row major C style!!! - int n, m, i, cur; + int n, m, i, cur; typedef FullBipartiteDigraph Digraph; DIGRAPH_TYPEDEFS(FullBipartiteDigraph); - std::vector<int> indI(n), indJ(m); - std::vector<double> weights1(n), weights2(m); - Digraph di(n, m); - NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter); - - // Get the number of non zero coordinates for r and c and vectors + // Get the number of non zero coordinates for r and c n=0; for (int i=0; i<n1; i++) { double val=*(X+i); if (val>0) { - weights1[ n ] = val; - indI[n++]=i; + n++; }else if(val<0){ return INFEASIBLE; } @@ -43,14 +37,42 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, for (int i=0; i<n2; i++) { double val=*(Y+i); if (val>0) { - weights2[ m ] = -val; - indJ[m++]=i; + m++; }else if(val<0){ return INFEASIBLE; } } // Define the graph + + std::vector<int> indI(n), indJ(m); + std::vector<double> weights1(n), weights2(m); + Digraph di(n, m); + NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter); + + // Set supply and demand, don't account for 0 values (faster) + + cur=0; + for (int i=0; i<n1; i++) { + double val=*(X+i); + if (val>0) { + weights1[ cur ] = val; + indI[cur++]=i; + } + } + + // Demand is actually negative supply... + + cur=0; + for (int i=0; i<n2; i++) { + double val=*(Y+i); + if (val>0) { + weights2[ cur ] = -val; + indJ[cur++]=i; + } + } + + net.supplyMap(&weights1[0], n, &weights2[0], m); // Set the cost of each edge @@ -147,8 +169,8 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); *(G+cur) = flow; - *(iG+cur) = i; - *(jG+cur) = j; + *(iG+cur) = indI[i]; + *(jG+cur) = indJ[j]; *(alpha + indI[i]) = -net.potential(i); *(beta + indJ[j-n]) = net.potential(j); cur++; |