summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorarolet <antoine.rolet@gmail.com>2017-07-21 12:12:21 +0900
committerarolet <antoine.rolet@gmail.com>2017-07-21 12:12:21 +0900
commitdc3bbd4134f0e2b80e0fe72368bdcf9966f434dc (patch)
tree00c58d3024e1b40c9d285148e9827d5dced64703
parent1fcb7d0ffbc5b00ed20b5ded2e7f1001dc914d6e (diff)
Cleaned optimal plan and optimal cost computation
-rw-r--r--ot/lp/EMD_wrapper.cpp13
-rw-r--r--ot/lp/emd_wrap.pyx5
-rw-r--r--test/test_emd.py10
3 files changed, 14 insertions, 14 deletions
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index d719c6e..cc13230 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -93,14 +93,13 @@ void EMD_wrap(int n1, int n2, double *X, double *Y,
}
} else
{
- for (node_id_type i=0; i<n; i++)
- {
- for (node_id_type j=0; j<m; j++)
- {
- *(G+indI[i]*n2+indJ[j]) = net.flow(di.arcFromId(i*m+j));
- }
- };
*cost = net.totalCost();
+ Arc a; di.first(a);
+ for (; a != INVALID; di.next(a)) {
+ int i = di.source(a);
+ int j = di.target(a);
+ *(G+indI[i]*n2+indJ[j-n]) = net.flow(a);
+ }
};
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index e8fdba4..c4ba125 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -121,11 +121,6 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
# calling the function
EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, maxiter)
-
- cost=0
- for i in range(n1):
- for j in range(n2):
- cost+=G[i,j]*M[i,j]
return cost
diff --git a/test/test_emd.py b/test/test_emd.py
index eb1c5c5..4757cd1 100644
--- a/test/test_emd.py
+++ b/test/test_emd.py
@@ -43,11 +43,17 @@ ot.toc('1 proc : {} s')
cost1 = (G * M).sum()
+# emd loss 1 proc
+ot.tic()
+cost_emd2 = ot.emd2(a,b,M)
+ot.toc('1 proc : {} s')
+
ot.tic()
G = ot.emd(b, a, np.ascontiguousarray(M.T))
ot.toc('1 proc : {} s')
cost2 = (G * M.T).sum()
-assert np.abs(cost1-cost2) < tol
-assert np.abs(cost1-np.abs(mean1-mean2)) < tol
+assert np.abs(cost1-cost_emd2)/np.abs(cost1) < tol
+assert np.abs(cost1-cost2)/np.abs(cost1) < tol
+assert np.abs(cost1-np.abs(mean1-mean2))/np.abs(cost1) < tol