summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2019-12-02 11:13:07 +0100
committerRémi Flamary <remi.flamary@gmail.com>2019-12-02 11:13:07 +0100
commit57321bd0172c97b77dfc8b14972c18d063b6dda8 (patch)
tree6d21c211a69cee58b4d62f2abcbd0e99e0f23808
parent4a6883e0ce2fd9f3edd374d54c4c219d876ceb76 (diff)
add awesome sparse solver
-rw-r--r--ot/lp/EMD_wrapper.cpp65
-rw-r--r--ot/lp/emd_wrap.pyx2
-rw-r--r--test/test_ot.py20
3 files changed, 67 insertions, 20 deletions
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index 3ca7319..2aa44c1 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -111,23 +111,19 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
long *iG, long *jG, double *G,
double* alpha, double* beta, double *cost, int maxIter) {
// beware M and C anre strored in row major C style!!!
- int n, m, i, cur;
+
+ // Get the number of non zero coordinates for r and c and vectors
+ int n, m, i, cur;
typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
- std::vector<int> indI(n), indJ(m);
- std::vector<double> weights1(n), weights2(m);
- Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
-
- // Get the number of non zero coordinates for r and c and vectors
+ // Get the number of non zero coordinates for r and c
n=0;
for (int i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
- weights1[ n ] = val;
- indI[n++]=i;
+ n++;
}else if(val<0){
return INFEASIBLE;
}
@@ -136,14 +132,42 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
for (int i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
- weights2[ m ] = -val;
- indJ[m++]=i;
+ m++;
}else if(val<0){
return INFEASIBLE;
}
}
// Define the graph
+
+ std::vector<int> indI(n), indJ(m);
+ std::vector<double> weights1(n), weights2(m);
+ Digraph di(n, m);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
+
+ // Set supply and demand, don't account for 0 values (faster)
+
+ cur=0;
+ for (int i=0; i<n1; i++) {
+ double val=*(X+i);
+ if (val>0) {
+ weights1[ cur ] = val;
+ indI[cur++]=i;
+ }
+ }
+
+ // Demand is actually negative supply...
+
+ cur=0;
+ for (int i=0; i<n2; i++) {
+ double val=*(Y+i);
+ if (val>0) {
+ weights2[ cur ] = -val;
+ indJ[cur++]=i;
+ }
+ }
+
+ // Define the graph
net.supplyMap(&weights1[0], n, &weights2[0], m);
// Set the cost of each edge
@@ -166,14 +190,17 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
int i = di.source(a);
int j = di.target(a);
double flow = net.flow(a);
- *cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
-
- *(G+cur) = flow;
- *(iG+cur) = indI[i];
- *(jG+cur) = indJ[j];
- *(alpha + indI[i]) = -net.potential(i);
- *(beta + indJ[j-n]) = net.potential(j);
- cur++;
+ if (flow>0)
+ {
+ *cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
+
+ *(G+cur) = flow;
+ *(iG+cur) = indI[i];
+ *(jG+cur) = indJ[j-n];
+ *(alpha + indI[i]) = -net.potential(i);
+ *(beta + indJ[j-n]) = net.potential(j);
+ cur++;
+ }
}
}
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 345cb66..f183995 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -111,7 +111,7 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
jG=np.zeros(nmax,dtype=np.int)
- result_code = EMD_wrap_return_sparse(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <long*> iG.data, <long*> jG.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
+ result_code = EMD_wrap_return_sparse(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <long*> iG.data, <long*> jG.data, <double*> Gv.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
return Gv, iG, jG, cost, alpha, beta, result_code
diff --git a/test/test_ot.py b/test/test_ot.py
index dacae0a..4d59e12 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -118,6 +118,26 @@ def test_emd_empty():
np.testing.assert_allclose(w, 0)
+def test_emd_sparse():
+
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ x2 = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x2)
+
+ G = ot.emd([], [], M)
+
+ Gs = ot.emd([], [], M, sparse=True)
+
+ # check G is the same
+ np.testing.assert_allclose(G, Gs.todense())
+ # check constraints
+
+
def test_emd2_multi():
n = 500 # nb bins