diff options
author | arolet <antoine.rolet@gmail.com> | 2017-07-14 15:37:46 +0900 |
---|---|---|
committer | arolet <antoine.rolet@gmail.com> | 2017-07-14 15:37:46 +0900 |
commit | d59e91450272c78dd0fdd3c6bd9bf48776f10070 (patch) | |
tree | c37e400f28fb84307f97662411ec9bb645295430 | |
parent | 0faef7fde7e64705b4f0ed6618a0cfd25319bdc7 (diff) |
Added a test based on closed form solution for gaussians
-rw-r--r-- | test/test_emd.py | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/test/test_emd.py b/test/test_emd.py index 3729d5d..eb1c5c5 100644 --- a/test/test_emd.py +++ b/test/test_emd.py @@ -11,17 +11,25 @@ reload(ot.lp) #%% parameters n=5000 # nb bins +m=6000 # nb bins + +mean1 = 1000 +mean2 = 1100 + +tol = 1e-6 # bin positions x=np.arange(n,dtype=np.float64) +y=np.arange(m,dtype=np.float64) # Gaussian distributions -a=gauss(n,m=20,s=5) # m= mean, s= std +a=gauss(n,m=mean1,s=5) # m= mean, s= std -b=gauss(n,m=30,s=10) +b=gauss(m,m=mean2,s=10) # loss matrix -M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) +M=ot.dist(x.reshape((-1,1)), y.reshape((-1,1))) ** (1./2) +print M[0,:] #M/=M.max() #%% @@ -30,6 +38,16 @@ print('Computing {} EMD '.format(1)) # emd loss 1 proc ot.tic() -emd_loss4 = ot.emd(a,b,M) +G = ot.emd(a,b,M) ot.toc('1 proc : {} s') +cost1 = (G * M).sum() + +ot.tic() +G = ot.emd(b, a, np.ascontiguousarray(M.T)) +ot.toc('1 proc : {} s') + +cost2 = (G * M.T).sum() + +assert np.abs(cost1-cost2) < tol +assert np.abs(cost1-np.abs(mean1-mean2)) < tol |