diff options
author | arolet <antoine.rolet@gmail.com> | 2017-07-21 12:12:21 +0900 |
---|---|---|
committer | arolet <antoine.rolet@gmail.com> | 2017-07-21 12:12:21 +0900 |
commit | dc3bbd4134f0e2b80e0fe72368bdcf9966f434dc (patch) | |
tree | 00c58d3024e1b40c9d285148e9827d5dced64703 | |
parent | 1fcb7d0ffbc5b00ed20b5ded2e7f1001dc914d6e (diff) |
Cleaned optimal plan and optimal cost computation
-rw-r--r-- | ot/lp/EMD_wrapper.cpp | 13 | ||||
-rw-r--r-- | ot/lp/emd_wrap.pyx | 5 | ||||
-rw-r--r-- | test/test_emd.py | 10 |
3 files changed, 14 insertions, 14 deletions
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index d719c6e..cc13230 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -93,14 +93,13 @@ void EMD_wrap(int n1, int n2, double *X, double *Y, } } else { - for (node_id_type i=0; i<n; i++) - { - for (node_id_type j=0; j<m; j++) - { - *(G+indI[i]*n2+indJ[j]) = net.flow(di.arcFromId(i*m+j)); - } - }; *cost = net.totalCost(); + Arc a; di.first(a); + for (; a != INVALID; di.next(a)) { + int i = di.source(a); + int j = di.target(a); + *(G+indI[i]*n2+indJ[j-n]) = net.flow(a); + } }; diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index e8fdba4..c4ba125 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -121,11 +121,6 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo # calling the function EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, maxiter) - - cost=0 - for i in range(n1): - for j in range(n2): - cost+=G[i,j]*M[i,j] return cost diff --git a/test/test_emd.py b/test/test_emd.py index eb1c5c5..4757cd1 100644 --- a/test/test_emd.py +++ b/test/test_emd.py @@ -43,11 +43,17 @@ ot.toc('1 proc : {} s') cost1 = (G * M).sum() +# emd loss 1 proc +ot.tic() +cost_emd2 = ot.emd2(a,b,M) +ot.toc('1 proc : {} s') + ot.tic() G = ot.emd(b, a, np.ascontiguousarray(M.T)) ot.toc('1 proc : {} s') cost2 = (G * M.T).sum() -assert np.abs(cost1-cost2) < tol -assert np.abs(cost1-np.abs(mean1-mean2)) < tol +assert np.abs(cost1-cost_emd2)/np.abs(cost1) < tol +assert np.abs(cost1-cost2)/np.abs(cost1) < tol +assert np.abs(cost1-np.abs(mean1-mean2))/np.abs(cost1) < tol |