summaryrefslogtreecommitdiff
path: root/examples/plot_OT_L1_vs_L2.py
diff options
context:
space:
mode:
authorAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-11 21:51:29 +0200
committerAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-20 14:05:12 +0200
commitc6cb1cd666a3e1b761b83a6e0f9339268e69f099 (patch)
treeed4b95b6d722eb74162b259c2007c13030b484ae /examples/plot_OT_L1_vs_L2.py
parent35b25adf1fd9ec35fb6f6da105e268ed5b467d79 (diff)
pimp + pep8 on plot_OT_L1_vs_L2
Diffstat (limited to 'examples/plot_OT_L1_vs_L2.py')
-rw-r--r--examples/plot_OT_L1_vs_L2.py168
1 files changed, 86 insertions, 82 deletions
diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py
index 9bb92fe..e11d6ad 100644
--- a/examples/plot_OT_L1_vs_L2.py
+++ b/examples/plot_OT_L1_vs_L2.py
@@ -4,7 +4,7 @@
2D Optimal transport for different metrics
==========================================
-Stole the figure idea from Fig. 1 and 2 in
+Stole the figure idea from Fig. 1 and 2 in
https://arxiv.org/pdf/1706.07650.pdf
@@ -12,7 +12,7 @@ https://arxiv.org/pdf/1706.07650.pdf
"""
import numpy as np
-import matplotlib.pylab as pl
+import matplotlib.pylab as plt
import ot
#%% parameters and data generation
@@ -20,89 +20,93 @@ import ot
for data in range(2):
if data:
- n=20 # nb samples
- xs=np.zeros((n,2))
- xs[:,0]=np.arange(n)+1
- xs[:,1]=(np.arange(n)+1)*-0.001 # to make it strictly convex...
-
- xt=np.zeros((n,2))
- xt[:,1]=np.arange(n)+1
+ n = 20 # nb samples
+ xs = np.zeros((n, 2))
+ xs[:, 0] = np.arange(n) + 1
+ xs[:, 1] = (np.arange(n) + 1) * -0.001 # to make it strictly convex...
+
+ xt = np.zeros((n, 2))
+ xt[:, 1] = np.arange(n) + 1
else:
-
- n=50 # nb samples
- xtot=np.zeros((n+1,2))
- xtot[:,0]=np.cos((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
- xtot[:,1]=np.sin((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
-
- xs=xtot[:n,:]
- xt=xtot[1:,:]
-
-
-
- a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples
-
+
+ n = 50 # nb samples
+ xtot = np.zeros((n + 1, 2))
+ xtot[:, 0] = np.cos(
+ (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+ xtot[:, 1] = np.sin(
+ (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+
+ xs = xtot[:n, :]
+ xt = xtot[1:, :]
+
+ a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
# loss matrix
- M1=ot.dist(xs,xt,metric='euclidean')
- M1/=M1.max()
-
+ M1 = ot.dist(xs, xt, metric='euclidean')
+ M1 /= M1.max()
+
# loss matrix
- M2=ot.dist(xs,xt,metric='sqeuclidean')
- M2/=M2.max()
-
+ M2 = ot.dist(xs, xt, metric='sqeuclidean')
+ M2 /= M2.max()
+
# loss matrix
- Mp=np.sqrt(ot.dist(xs,xt,metric='euclidean'))
- Mp/=Mp.max()
-
+ Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+ Mp /= Mp.max()
+
#%% plot samples
-
- pl.figure(1+3*data)
- pl.clf()
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
- pl.axis('equal')
- pl.title('Source and traget distributions')
-
- pl.figure(2+3*data,(15,5))
- pl.subplot(1,3,1)
- pl.imshow(M1,interpolation='nearest')
- pl.title('Eucidean cost')
- pl.subplot(1,3,2)
- pl.imshow(M2,interpolation='nearest')
- pl.title('Squared Euclidean cost')
-
- pl.subplot(1,3,3)
- pl.imshow(Mp,interpolation='nearest')
- pl.title('Sqrt Euclidean cost')
+
+ plt.figure(1 + 3 * data, figsize=(7, 3))
+ plt.clf()
+ plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ plt.axis('equal')
+ plt.title('Source and traget distributions')
+
+ plt.figure(2 + 3 * data, figsize=(7, 3))
+
+ plt.subplot(1, 3, 1)
+ plt.imshow(M1, interpolation='nearest')
+ plt.title('Euclidean cost')
+
+ plt.subplot(1, 3, 2)
+ plt.imshow(M2, interpolation='nearest')
+ plt.title('Squared Euclidean cost')
+
+ plt.subplot(1, 3, 3)
+ plt.imshow(Mp, interpolation='nearest')
+ plt.title('Sqrt Euclidean cost')
+ plt.tight_layout()
+
#%% EMD
-
- G1=ot.emd(a,b,M1)
- G2=ot.emd(a,b,M2)
- Gp=ot.emd(a,b,Mp)
-
- pl.figure(3+3*data,(15,5))
-
- pl.subplot(1,3,1)
- ot.plot.plot2D_samples_mat(xs,xt,G1,c=[.5,.5,1])
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
- pl.axis('equal')
- #pl.legend(loc=0)
- pl.title('OT Euclidean')
-
- pl.subplot(1,3,2)
-
- ot.plot.plot2D_samples_mat(xs,xt,G2,c=[.5,.5,1])
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
- pl.axis('equal')
- #pl.legend(loc=0)
- pl.title('OT squared Euclidean')
-
- pl.subplot(1,3,3)
-
- ot.plot.plot2D_samples_mat(xs,xt,Gp,c=[.5,.5,1])
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
- pl.axis('equal')
- #pl.legend(loc=0)
- pl.title('OT sqrt Euclidean')
+ G1 = ot.emd(a, b, M1)
+ G2 = ot.emd(a, b, M2)
+ Gp = ot.emd(a, b, Mp)
+
+ plt.figure(3 + 3 * data, figsize=(7, 3))
+
+ plt.subplot(1, 3, 1)
+ ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])
+ plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ plt.axis('equal')
+ # plt.legend(loc=0)
+ plt.title('OT Euclidean')
+
+ plt.subplot(1, 3, 2)
+ ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])
+ plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ plt.axis('equal')
+ # plt.legend(loc=0)
+ plt.title('OT squared Euclidean')
+
+ plt.subplot(1, 3, 3)
+ ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])
+ plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ plt.axis('equal')
+ # plt.legend(loc=0)
+ plt.title('OT sqrt Euclidean')
+ plt.tight_layout()
+
+plt.show()