diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2019-12-03 15:20:16 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2019-12-03 15:20:16 +0100 |
commit | a4afee871d8e9d5db68228d1ed5bf4853eedc294 (patch) | |
tree | 9aed5291d5e472085b0eca3d0453865684bd443d /ot/lp | |
parent | c439e3efb920086154c741b41f65d99165e875d8 (diff) |
first implemntation sparse loss
Diffstat (limited to 'ot/lp')
-rw-r--r-- | ot/lp/EMD.h | 5 | ||||
-rw-r--r-- | ot/lp/EMD_wrapper.cpp | 78 | ||||
-rw-r--r-- | ot/lp/emd_wrap.pyx | 4 | ||||
-rw-r--r-- | ot/lp/network_simplex_simple.h | 2 |
4 files changed, 88 insertions, 1 deletions
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index 9896091..fc94211 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -36,4 +36,9 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, long *iG, long *jG, double *G, long * nG, double* alpha, double* beta, double *cost, int maxIter); +int EMD_wrap_all_sparse(int n1, int n2, double *X, double *Y, + long *iD, long *jD, double *D, long nD, + long *iG, long *jG, double *G, long * nG, + double* alpha, double* beta, double *cost, int maxIter); + #endif diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 9be2cdc..28e4af2 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -210,3 +210,81 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, return ret; } +int EMD_wrap_all_sparse(int n1, int n2, double *X, double *Y, + long *iD, long *jD, double *D, long nD, + long *iG, long *jG, double *G, long * nG, + double* alpha, double* beta, double *cost, int maxIter) { + // beware M and C anre strored in row major C style!!! + + // Get the number of non zero coordinates for r and c and vectors + int n, m, cur; + + typedef FullBipartiteDigraph Digraph; + DIGRAPH_TYPEDEFS(FullBipartiteDigraph); + + n=n1; + m=n2; + + + // Define the graph + + + std::vector<double> weights2(m); + Digraph di(n, m); + NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter); + + // Set supply and demand, don't account for 0 values (faster) + + + // Demand is actually negative supply... + + cur=0; + for (int i=0; i<n2; i++) { + double val=*(Y+i); + if (val>0) { + weights2[ cur ] = -val; + } + } + + // Define the graph + net.supplyMap(X, n, &weights2[0], m); + + // Set the cost of each edge + for (int k=0; k<nD; k++) { + int i = iD[k]; + int j = jD[k]; + net.setCost(di.arcFromId(i*m+j), D[k]); + + } + + + // Solve the problem with the network simplex algorithm + + int ret=net.run(); + if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) { + *cost = net.totalCost(); + Arc a; di.first(a); + cur=0; + for (; a != INVALID; di.next(a)) { + int i = di.source(a); + int j = di.target(a); + double flow = net.flow(a); + if (flow>0) + { + + *(G+cur) = flow; + *(iG+cur) = i; + *(jG+cur) = j-n; + *(alpha + i) = -net.potential(i); + *(beta + j-n) = net.potential(j); + cur++; + } + } + *nG=cur; // nb of value +1 for numpy indexing + + } + + + return ret; +} + diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 4b6cdce..4e3586d 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -23,6 +23,10 @@ cdef extern from "EMD.h": int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, long *iG, long *jG, double *G, long * nG, double* alpha, double* beta, double *cost, int maxIter) + int EMD_wrap_all_sparse(int n1, int n2, double *X, double *Y, + long *iD, long *jD, double *D, long nD, + long *iG, long *jG, double *G, long * nG, + double* alpha, double* beta, double *cost, int maxIter) cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 7c6a4ce..498e921 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -686,7 +686,7 @@ namespace lemon { /// \see resetParams(), reset() ProblemType run() { #if DEBUG_LVL>0 - std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED\n"; + std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED << "\n" ; #endif if (!init()) return INFEASIBLE; |