diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-03 16:03:13 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-03 16:03:13 +0200 |
commit | f639518e9b96c5904122e62e024ed4ae369ceb33 (patch) | |
tree | a67806a07047892e96121e099afde605bc80eb06 /examples | |
parent | 4fba2c9c479e8f410a23ef24458effc29fc3f7f0 (diff) |
add example norm
Diffstat (limited to 'examples')
-rw-r--r-- | examples/plot_OT_L1_vs_L2.py | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py new file mode 100644 index 0000000..1ab44a0 --- /dev/null +++ b/examples/plot_OT_L1_vs_L2.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +2D Optimal transport between empirical distributions +==================================================== + +Stoile the figure idea from: +https://arxiv.org/pdf/1706.07650.pdf + + +@author: rflamary +""" + +import numpy as np +import matplotlib.pylab as pl +import ot + +#%% parameters and data generation + +for data in range(2): + + if data: + n=20 # nb samples + xs=np.zeros((n,2)) + xs[:,0]=np.arange(n)+1 + xs[:,1]=(np.arange(n)+1)*-0.001 # to make it strictly convex... + + xt=np.zeros((n,2)) + xt[:,1]=np.arange(n)+1 + else: + + n=50 # nb samples + xtot=np.zeros((n+1,2)) + xtot[:,0]=np.cos((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi) + xtot[:,1]=np.sin((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi) + + xs=xtot[:n,:] + xt=xtot[1:,:] + + + + a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples + + # loss matrix + M1=ot.dist(xs,xt,metric='euclidean') + M1/=M1.max() + + # loss matrix + M2=ot.dist(xs,xt,metric='sqeuclidean') + M2/=M2.max() + + # loss matrix + Mp=np.sqrt(ot.dist(xs,xt,metric='euclidean')) + Mp/=Mp.max() + + #%% plot samples + + pl.figure(1+3*data) + pl.clf() + pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') + pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') + pl.axis('equal') + pl.title('Source and traget distributions') + + pl.figure(2+3*data,(15,5)) + pl.subplot(1,3,1) + pl.imshow(M1,interpolation='nearest') + pl.title('Eucidean cost') + pl.subplot(1,3,2) + pl.imshow(M2,interpolation='nearest') + pl.title('Squared Euclidean cost') + + pl.subplot(1,3,3) + pl.imshow(Mp,interpolation='nearest') + pl.title('Sqrt Euclidean cost') + #%% EMD + + G1=ot.emd(a,b,M1) + G2=ot.emd(a,b,M2) + Gp=ot.emd(a,b,Mp) + + pl.figure(3+3*data,(15,5)) + + pl.subplot(1,3,1) + ot.plot.plot2D_samples_mat(xs,xt,G1,c=[.5,.5,1]) + pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') + pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') + pl.axis('equal') + #pl.legend(loc=0) + pl.title('OT Euclidean') + + pl.subplot(1,3,2) + + ot.plot.plot2D_samples_mat(xs,xt,G2,c=[.5,.5,1]) + pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') + pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') + pl.axis('equal') + #pl.legend(loc=0) + pl.title('OT squared Euclidean') + + pl.subplot(1,3,3) + + ot.plot.plot2D_samples_mat(xs,xt,Gp,c=[.5,.5,1]) + pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') + pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') + pl.axis('equal') + #pl.legend(loc=0) + pl.title('OT sqrt Euclidean') + +#%% sinkhorn |