diff options
Diffstat (limited to 'test/test_emd.py')
-rw-r--r-- | test/test_emd.py | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/test/test_emd.py b/test/test_emd.py index eb1c5c5..4757cd1 100644 --- a/test/test_emd.py +++ b/test/test_emd.py @@ -43,11 +43,17 @@ ot.toc('1 proc : {} s') cost1 = (G * M).sum() +# emd loss 1 proc +ot.tic() +cost_emd2 = ot.emd2(a,b,M) +ot.toc('1 proc : {} s') + ot.tic() G = ot.emd(b, a, np.ascontiguousarray(M.T)) ot.toc('1 proc : {} s') cost2 = (G * M.T).sum() -assert np.abs(cost1-cost2) < tol -assert np.abs(cost1-np.abs(mean1-mean2)) < tol +assert np.abs(cost1-cost_emd2)/np.abs(cost1) < tol +assert np.abs(cost1-cost2)/np.abs(cost1) < tol +assert np.abs(cost1-np.abs(mean1-mean2))/np.abs(cost1) < tol |