summaryrefslogtreecommitdiff
path: root/ot/lp
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2019-12-03 15:20:16 +0100
committerRémi Flamary <remi.flamary@gmail.com>2019-12-03 15:20:16 +0100
commita4afee871d8e9d5db68228d1ed5bf4853eedc294 (patch)
tree9aed5291d5e472085b0eca3d0453865684bd443d /ot/lp
parentc439e3efb920086154c741b41f65d99165e875d8 (diff)
first implemntation sparse loss
Diffstat (limited to 'ot/lp')
-rw-r--r--ot/lp/EMD.h5
-rw-r--r--ot/lp/EMD_wrapper.cpp78
-rw-r--r--ot/lp/emd_wrap.pyx4
-rw-r--r--ot/lp/network_simplex_simple.h2
4 files changed, 88 insertions, 1 deletions
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
index 9896091..fc94211 100644
--- a/ot/lp/EMD.h
+++ b/ot/lp/EMD.h
@@ -36,4 +36,9 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
long *iG, long *jG, double *G, long * nG,
double* alpha, double* beta, double *cost, int maxIter);
+int EMD_wrap_all_sparse(int n1, int n2, double *X, double *Y,
+ long *iD, long *jD, double *D, long nD,
+ long *iG, long *jG, double *G, long * nG,
+ double* alpha, double* beta, double *cost, int maxIter);
+
#endif
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index 9be2cdc..28e4af2 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -210,3 +210,81 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
return ret;
}
+int EMD_wrap_all_sparse(int n1, int n2, double *X, double *Y,
+ long *iD, long *jD, double *D, long nD,
+ long *iG, long *jG, double *G, long * nG,
+ double* alpha, double* beta, double *cost, int maxIter) {
+ // beware M and C anre strored in row major C style!!!
+
+ // Get the number of non zero coordinates for r and c and vectors
+ int n, m, cur;
+
+ typedef FullBipartiteDigraph Digraph;
+ DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
+
+ n=n1;
+ m=n2;
+
+
+ // Define the graph
+
+
+ std::vector<double> weights2(m);
+ Digraph di(n, m);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
+
+ // Set supply and demand, don't account for 0 values (faster)
+
+
+ // Demand is actually negative supply...
+
+ cur=0;
+ for (int i=0; i<n2; i++) {
+ double val=*(Y+i);
+ if (val>0) {
+ weights2[ cur ] = -val;
+ }
+ }
+
+ // Define the graph
+ net.supplyMap(X, n, &weights2[0], m);
+
+ // Set the cost of each edge
+ for (int k=0; k<nD; k++) {
+ int i = iD[k];
+ int j = jD[k];
+ net.setCost(di.arcFromId(i*m+j), D[k]);
+
+ }
+
+
+ // Solve the problem with the network simplex algorithm
+
+ int ret=net.run();
+ if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
+ *cost = net.totalCost();
+ Arc a; di.first(a);
+ cur=0;
+ for (; a != INVALID; di.next(a)) {
+ int i = di.source(a);
+ int j = di.target(a);
+ double flow = net.flow(a);
+ if (flow>0)
+ {
+
+ *(G+cur) = flow;
+ *(iG+cur) = i;
+ *(jG+cur) = j-n;
+ *(alpha + i) = -net.potential(i);
+ *(beta + j-n) = net.potential(j);
+ cur++;
+ }
+ }
+ *nG=cur; // nb of value +1 for numpy indexing
+
+ }
+
+
+ return ret;
+}
+
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 4b6cdce..4e3586d 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -23,6 +23,10 @@ cdef extern from "EMD.h":
int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
long *iG, long *jG, double *G, long * nG,
double* alpha, double* beta, double *cost, int maxIter)
+ int EMD_wrap_all_sparse(int n1, int n2, double *X, double *Y,
+ long *iD, long *jD, double *D, long nD,
+ long *iG, long *jG, double *G, long * nG,
+ double* alpha, double* beta, double *cost, int maxIter)
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h
index 7c6a4ce..498e921 100644
--- a/ot/lp/network_simplex_simple.h
+++ b/ot/lp/network_simplex_simple.h
@@ -686,7 +686,7 @@ namespace lemon {
/// \see resetParams(), reset()
ProblemType run() {
#if DEBUG_LVL>0
- std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED\n";
+ std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED << "\n" ;
#endif
if (!init()) return INFEASIBLE;