summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-03 16:03:13 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-03 16:03:13 +0200
commitf639518e9b96c5904122e62e024ed4ae369ceb33 (patch)
treea67806a07047892e96121e099afde605bc80eb06
parent4fba2c9c479e8f410a23ef24458effc29fc3f7f0 (diff)
add example norm
-rw-r--r--examples/plot_OT_L1_vs_L2.py110
1 files changed, 110 insertions, 0 deletions
diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py
new file mode 100644
index 0000000..1ab44a0
--- /dev/null
+++ b/examples/plot_OT_L1_vs_L2.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================
+2D Optimal transport between empirical distributions
+====================================================
+
+Stoile the figure idea from:
+https://arxiv.org/pdf/1706.07650.pdf
+
+
+@author: rflamary
+"""
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+#%% parameters and data generation
+
+for data in range(2):
+
+ if data:
+ n=20 # nb samples
+ xs=np.zeros((n,2))
+ xs[:,0]=np.arange(n)+1
+ xs[:,1]=(np.arange(n)+1)*-0.001 # to make it strictly convex...
+
+ xt=np.zeros((n,2))
+ xt[:,1]=np.arange(n)+1
+ else:
+
+ n=50 # nb samples
+ xtot=np.zeros((n+1,2))
+ xtot[:,0]=np.cos((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
+ xtot[:,1]=np.sin((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
+
+ xs=xtot[:n,:]
+ xt=xtot[1:,:]
+
+
+
+ a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples
+
+ # loss matrix
+ M1=ot.dist(xs,xt,metric='euclidean')
+ M1/=M1.max()
+
+ # loss matrix
+ M2=ot.dist(xs,xt,metric='sqeuclidean')
+ M2/=M2.max()
+
+ # loss matrix
+ Mp=np.sqrt(ot.dist(xs,xt,metric='euclidean'))
+ Mp/=Mp.max()
+
+ #%% plot samples
+
+ pl.figure(1+3*data)
+ pl.clf()
+ pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
+ pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+ pl.axis('equal')
+ pl.title('Source and traget distributions')
+
+ pl.figure(2+3*data,(15,5))
+ pl.subplot(1,3,1)
+ pl.imshow(M1,interpolation='nearest')
+ pl.title('Eucidean cost')
+ pl.subplot(1,3,2)
+ pl.imshow(M2,interpolation='nearest')
+ pl.title('Squared Euclidean cost')
+
+ pl.subplot(1,3,3)
+ pl.imshow(Mp,interpolation='nearest')
+ pl.title('Sqrt Euclidean cost')
+ #%% EMD
+
+ G1=ot.emd(a,b,M1)
+ G2=ot.emd(a,b,M2)
+ Gp=ot.emd(a,b,Mp)
+
+ pl.figure(3+3*data,(15,5))
+
+ pl.subplot(1,3,1)
+ ot.plot.plot2D_samples_mat(xs,xt,G1,c=[.5,.5,1])
+ pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
+ pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+ pl.axis('equal')
+ #pl.legend(loc=0)
+ pl.title('OT Euclidean')
+
+ pl.subplot(1,3,2)
+
+ ot.plot.plot2D_samples_mat(xs,xt,G2,c=[.5,.5,1])
+ pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
+ pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+ pl.axis('equal')
+ #pl.legend(loc=0)
+ pl.title('OT squared Euclidean')
+
+ pl.subplot(1,3,3)
+
+ ot.plot.plot2D_samples_mat(xs,xt,Gp,c=[.5,.5,1])
+ pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
+ pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+ pl.axis('equal')
+ #pl.legend(loc=0)
+ pl.title('OT sqrt Euclidean')
+
+#%% sinkhorn