summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-03-14 10:30:45 +0100
committerRémi Flamary <remi.flamary@gmail.com>2017-03-14 10:30:45 +0100
commita84f2c3e23edd1fa89975bd77b08672f518d5ca4 (patch)
tree1aadf05357949e6daec2c332eb900e93346ad465 /test
parent84219d9bd87acd9bbb6d1a832cf4ccaee53fed0b (diff)
add emd2+ multiproc
Diffstat (limited to 'test')
-rw-r--r--test/test_emd_multi.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/test/test_emd_multi.py b/test/test_emd_multi.py
new file mode 100644
index 0000000..ee0a20e
--- /dev/null
+++ b/test/test_emd_multi.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python2
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Mar 10 09:56:06 2017
+
+@author: rflamary
+"""
+
+import numpy as np
+import pylab as pl
+import ot
+
+from ot.datasets import get_1D_gauss as gauss
+reload(ot.lp)
+
+#%% parameters
+
+n=5000 # nb bins
+
+# bin positions
+x=np.arange(n,dtype=np.float64)
+
+# Gaussian distributions
+a=gauss(n,m=20,s=5) # m= mean, s= std
+
+ls= range(20,1000,10)
+nb=len(ls)
+b=np.zeros((n,nb))
+for i in range(nb):
+ b[:,i]=gauss(n,m=ls[i],s=10)
+
+# loss matrix
+M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
+#M/=M.max()
+
+#%%
+
+print('Computing {} EMD '.format(nb))
+
+# emd loss 1 proc
+ot.tic()
+emd_loss4=ot.emd2(a,b,M,1)
+ot.toc('1 proc : {} s')
+
+# emd loss multipro proc
+ot.tic()
+emd_loss4=ot.emd2(a,b,M)
+ot.toc('multi proc : {} s')