summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorarolet <antoine.rolet@gmail.com>2017-07-14 15:37:46 +0900
committerarolet <antoine.rolet@gmail.com>2017-07-14 15:37:46 +0900
commitd59e91450272c78dd0fdd3c6bd9bf48776f10070 (patch)
treec37e400f28fb84307f97662411ec9bb645295430 /test
parent0faef7fde7e64705b4f0ed6618a0cfd25319bdc7 (diff)
Added a test based on closed form solution for gaussians
Diffstat (limited to 'test')
-rw-r--r--test/test_emd.py26
1 files changed, 22 insertions, 4 deletions
diff --git a/test/test_emd.py b/test/test_emd.py
index 3729d5d..eb1c5c5 100644
--- a/test/test_emd.py
+++ b/test/test_emd.py
@@ -11,17 +11,25 @@ reload(ot.lp)
#%% parameters
n=5000 # nb bins
+m=6000 # nb bins
+
+mean1 = 1000
+mean2 = 1100
+
+tol = 1e-6
# bin positions
x=np.arange(n,dtype=np.float64)
+y=np.arange(m,dtype=np.float64)
# Gaussian distributions
-a=gauss(n,m=20,s=5) # m= mean, s= std
+a=gauss(n,m=mean1,s=5) # m= mean, s= std
-b=gauss(n,m=30,s=10)
+b=gauss(m,m=mean2,s=10)
# loss matrix
-M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
+M=ot.dist(x.reshape((-1,1)), y.reshape((-1,1))) ** (1./2)
+print M[0,:]
#M/=M.max()
#%%
@@ -30,6 +38,16 @@ print('Computing {} EMD '.format(1))
# emd loss 1 proc
ot.tic()
-emd_loss4 = ot.emd(a,b,M)
+G = ot.emd(a,b,M)
ot.toc('1 proc : {} s')
+cost1 = (G * M).sum()
+
+ot.tic()
+G = ot.emd(b, a, np.ascontiguousarray(M.T))
+ot.toc('1 proc : {} s')
+
+cost2 = (G * M.T).sum()
+
+assert np.abs(cost1-cost2) < tol
+assert np.abs(cost1-np.abs(mean1-mean2)) < tol