From c1980a414c879dd1bc1d8904fd43426326741385 Mon Sep 17 00:00:00 2001 From: arolet Date: Fri, 21 Jul 2017 13:34:09 +0900 Subject: Added and passed tests for dual variables --- ot/lp/EMD_wrapper.cpp | 2 +- ot/lp/emd_wrap.pyx | 4 ++-- test/test_emd.py | 28 +++++++++++++++++++--------- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index c6cbb04..0977e75 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -101,7 +101,7 @@ void EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, double flow = net.flow(a); *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); *(G+indI[i]*n2+indJ[j-n]) = flow; - *(alpha + indI[i]) = net.potential(i); + *(alpha + indI[i]) = -net.potential(i); *(beta + indJ[j-n]) = net.potential(j); } diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 813596f..435a270 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -57,7 +57,7 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod cdef int n1= M.shape[0] cdef int n2= M.shape[1] - cdef float cost=0 + cdef double cost=0 cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1) cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2) @@ -116,7 +116,7 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo cdef int n1= M.shape[0] cdef int n2= M.shape[1] - cdef float cost=0 + cdef double cost=0 cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) cdef np.ndarray[double, ndim = 1, mode = "c"] alpha = np.zeros([n1]) diff --git a/test/test_emd.py b/test/test_emd.py index 4757cd1..3bf6fa2 100644 --- a/test/test_emd.py +++ b/test/test_emd.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- import numpy as np -import pylab as pl import ot from ot.datasets import get_1D_gauss as gauss @@ -16,8 +15,6 @@ m=6000 # nb bins mean1 = 1000 mean2 = 1100 -tol = 1e-6 - # bin positions x=np.arange(n,dtype=np.float64) y=np.arange(m,dtype=np.float64) @@ -38,10 +35,11 @@ print('Computing {} EMD '.format(1)) # emd loss 1 proc ot.tic() -G = ot.emd(a,b,M) +G, alpha, beta = ot.emd(a,b,M, dual_variables=True) ot.toc('1 proc : {} s') cost1 = (G * M).sum() +cost_dual = np.vdot(a, alpha) + np.vdot(b, beta) # emd loss 1 proc ot.tic() @@ -49,11 +47,23 @@ cost_emd2 = ot.emd2(a,b,M) ot.toc('1 proc : {} s') ot.tic() -G = ot.emd(b, a, np.ascontiguousarray(M.T)) +G2 = ot.emd(b, a, np.ascontiguousarray(M.T)) ot.toc('1 proc : {} s') -cost2 = (G * M.T).sum() +cost2 = (G2 * M.T).sum() + +M_reduced = M - alpha.reshape(-1,1) - beta.reshape(1, -1) + +# Check that both cost computations are equivalent +np.testing.assert_almost_equal(cost1, cost_emd2) +# Check that dual and primal cost are equal +np.testing.assert_almost_equal(cost1, cost_dual) +# Check symmetry +np.testing.assert_almost_equal(cost1, cost2) +# Check with closed-form solution for gaussians +np.testing.assert_almost_equal(cost1, np.abs(mean1-mean2)) + +[ind1, ind2] = np.nonzero(G) -assert np.abs(cost1-cost_emd2)/np.abs(cost1) < tol -assert np.abs(cost1-cost2)/np.abs(cost1) < tol -assert np.abs(cost1-np.abs(mean1-mean2))/np.abs(cost1) < tol +# Check that reduced cost is zero on transport arcs +np.testing.assert_array_almost_equal((M - alpha.reshape(-1, 1) - beta.reshape(1, -1))[ind1, ind2], np.zeros(ind1.size)) \ No newline at end of file -- cgit v1.2.3