summaryrefslogtreecommitdiff
path: root/ot/lp/EMD_wrapper.cpp
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2019-12-03 15:20:16 +0100
committerRémi Flamary <remi.flamary@gmail.com>2019-12-03 15:20:16 +0100
commita4afee871d8e9d5db68228d1ed5bf4853eedc294 (patch)
tree9aed5291d5e472085b0eca3d0453865684bd443d /ot/lp/EMD_wrapper.cpp
parentc439e3efb920086154c741b41f65d99165e875d8 (diff)
first implemntation sparse loss
Diffstat (limited to 'ot/lp/EMD_wrapper.cpp')
-rw-r--r--ot/lp/EMD_wrapper.cpp78
1 files changed, 78 insertions, 0 deletions
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index 9be2cdc..28e4af2 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -210,3 +210,81 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
return ret;
}
+int EMD_wrap_all_sparse(int n1, int n2, double *X, double *Y,
+ long *iD, long *jD, double *D, long nD,
+ long *iG, long *jG, double *G, long * nG,
+ double* alpha, double* beta, double *cost, int maxIter) {
+ // beware M and C anre strored in row major C style!!!
+
+ // Get the number of non zero coordinates for r and c and vectors
+ int n, m, cur;
+
+ typedef FullBipartiteDigraph Digraph;
+ DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
+
+ n=n1;
+ m=n2;
+
+
+ // Define the graph
+
+
+ std::vector<double> weights2(m);
+ Digraph di(n, m);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
+
+ // Set supply and demand, don't account for 0 values (faster)
+
+
+ // Demand is actually negative supply...
+
+ cur=0;
+ for (int i=0; i<n2; i++) {
+ double val=*(Y+i);
+ if (val>0) {
+ weights2[ cur ] = -val;
+ }
+ }
+
+ // Define the graph
+ net.supplyMap(X, n, &weights2[0], m);
+
+ // Set the cost of each edge
+ for (int k=0; k<nD; k++) {
+ int i = iD[k];
+ int j = jD[k];
+ net.setCost(di.arcFromId(i*m+j), D[k]);
+
+ }
+
+
+ // Solve the problem with the network simplex algorithm
+
+ int ret=net.run();
+ if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
+ *cost = net.totalCost();
+ Arc a; di.first(a);
+ cur=0;
+ for (; a != INVALID; di.next(a)) {
+ int i = di.source(a);
+ int j = di.target(a);
+ double flow = net.flow(a);
+ if (flow>0)
+ {
+
+ *(G+cur) = flow;
+ *(iG+cur) = i;
+ *(jG+cur) = j-n;
+ *(alpha + i) = -net.potential(i);
+ *(beta + j-n) = net.potential(j);
+ cur++;
+ }
+ }
+ *nG=cur; // nb of value +1 for numpy indexing
+
+ }
+
+
+ return ret;
+}
+