diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2019-12-02 11:13:07 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2019-12-02 11:13:07 +0100 |
commit | 57321bd0172c97b77dfc8b14972c18d063b6dda8 (patch) | |
tree | 6d21c211a69cee58b4d62f2abcbd0e99e0f23808 | |
parent | 4a6883e0ce2fd9f3edd374d54c4c219d876ceb76 (diff) |
add awesome sparse solver
-rw-r--r-- | ot/lp/EMD_wrapper.cpp | 65 | ||||
-rw-r--r-- | ot/lp/emd_wrap.pyx | 2 | ||||
-rw-r--r-- | test/test_ot.py | 20 |
3 files changed, 67 insertions, 20 deletions
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 3ca7319..2aa44c1 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -111,23 +111,19 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, long *iG, long *jG, double *G, double* alpha, double* beta, double *cost, int maxIter) { // beware M and C anre strored in row major C style!!! - int n, m, i, cur; + + // Get the number of non zero coordinates for r and c and vectors + int n, m, i, cur; typedef FullBipartiteDigraph Digraph; DIGRAPH_TYPEDEFS(FullBipartiteDigraph); - std::vector<int> indI(n), indJ(m); - std::vector<double> weights1(n), weights2(m); - Digraph di(n, m); - NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter); - - // Get the number of non zero coordinates for r and c and vectors + // Get the number of non zero coordinates for r and c n=0; for (int i=0; i<n1; i++) { double val=*(X+i); if (val>0) { - weights1[ n ] = val; - indI[n++]=i; + n++; }else if(val<0){ return INFEASIBLE; } @@ -136,14 +132,42 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, for (int i=0; i<n2; i++) { double val=*(Y+i); if (val>0) { - weights2[ m ] = -val; - indJ[m++]=i; + m++; }else if(val<0){ return INFEASIBLE; } } // Define the graph + + std::vector<int> indI(n), indJ(m); + std::vector<double> weights1(n), weights2(m); + Digraph di(n, m); + NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter); + + // Set supply and demand, don't account for 0 values (faster) + + cur=0; + for (int i=0; i<n1; i++) { + double val=*(X+i); + if (val>0) { + weights1[ cur ] = val; + indI[cur++]=i; + } + } + + // Demand is actually negative supply... + + cur=0; + for (int i=0; i<n2; i++) { + double val=*(Y+i); + if (val>0) { + weights2[ cur ] = -val; + indJ[cur++]=i; + } + } + + // Define the graph net.supplyMap(&weights1[0], n, &weights2[0], m); // Set the cost of each edge @@ -166,14 +190,17 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, int i = di.source(a); int j = di.target(a); double flow = net.flow(a); - *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); - - *(G+cur) = flow; - *(iG+cur) = indI[i]; - *(jG+cur) = indJ[j]; - *(alpha + indI[i]) = -net.potential(i); - *(beta + indJ[j-n]) = net.potential(j); - cur++; + if (flow>0) + { + *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); + + *(G+cur) = flow; + *(iG+cur) = indI[i]; + *(jG+cur) = indJ[j-n]; + *(alpha + indI[i]) = -net.potential(i); + *(beta + indJ[j-n]) = net.potential(j); + cur++; + } } } diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 345cb66..f183995 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -111,7 +111,7 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod jG=np.zeros(nmax,dtype=np.int) - result_code = EMD_wrap_return_sparse(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <long*> iG.data, <long*> jG.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) + result_code = EMD_wrap_return_sparse(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <long*> iG.data, <long*> jG.data, <double*> Gv.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) return Gv, iG, jG, cost, alpha, beta, result_code diff --git a/test/test_ot.py b/test/test_ot.py index dacae0a..4d59e12 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -118,6 +118,26 @@ def test_emd_empty(): np.testing.assert_allclose(w, 0) +def test_emd_sparse(): + + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + x2 = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x2) + + G = ot.emd([], [], M) + + Gs = ot.emd([], [], M, sparse=True) + + # check G is the same + np.testing.assert_allclose(G, Gs.todense()) + # check constraints + + def test_emd2_multi(): n = 500 # nb bins |