summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_OT_L1_vs_L2.rst
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-08-30 17:01:01 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-08-30 17:01:01 +0200
commitdc8737a30cb6d9f1305173eb8d16fe6716fd1231 (patch)
tree1f03384de2af88ed07a1e850e0871db826ed53e7 /docs/source/auto_examples/plot_OT_L1_vs_L2.rst
parentc2a7a1f3ab4ba5c4f5adeca0fa22d8d6b4fc079d (diff)
wroking make!
Diffstat (limited to 'docs/source/auto_examples/plot_OT_L1_vs_L2.rst')
-rw-r--r--docs/source/auto_examples/plot_OT_L1_vs_L2.rst151
1 files changed, 79 insertions, 72 deletions
diff --git a/docs/source/auto_examples/plot_OT_L1_vs_L2.rst b/docs/source/auto_examples/plot_OT_L1_vs_L2.rst
index 4e94bef..ba52bfe 100644
--- a/docs/source/auto_examples/plot_OT_L1_vs_L2.rst
+++ b/docs/source/auto_examples/plot_OT_L1_vs_L2.rst
@@ -7,11 +7,10 @@
2D Optimal transport for different metrics
==========================================
-Stole the figure idea from Fig. 1 and 2 in
+Stole the figure idea from Fig. 1 and 2 in
https://arxiv.org/pdf/1706.07650.pdf
-@author: rflamary
@@ -56,6 +55,10 @@ https://arxiv.org/pdf/1706.07650.pdf
.. code-block:: python
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
+
import numpy as np
import matplotlib.pylab as pl
import ot
@@ -65,94 +68,98 @@ https://arxiv.org/pdf/1706.07650.pdf
for data in range(2):
if data:
- n=20 # nb samples
- xs=np.zeros((n,2))
- xs[:,0]=np.arange(n)+1
- xs[:,1]=(np.arange(n)+1)*-0.001 # to make it strictly convex...
-
- xt=np.zeros((n,2))
- xt[:,1]=np.arange(n)+1
+ n = 20 # nb samples
+ xs = np.zeros((n, 2))
+ xs[:, 0] = np.arange(n) + 1
+ xs[:, 1] = (np.arange(n) + 1) * -0.001 # to make it strictly convex...
+
+ xt = np.zeros((n, 2))
+ xt[:, 1] = np.arange(n) + 1
else:
-
- n=50 # nb samples
- xtot=np.zeros((n+1,2))
- xtot[:,0]=np.cos((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
- xtot[:,1]=np.sin((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
-
- xs=xtot[:n,:]
- xt=xtot[1:,:]
-
-
-
- a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples
-
+
+ n = 50 # nb samples
+ xtot = np.zeros((n + 1, 2))
+ xtot[:, 0] = np.cos(
+ (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+ xtot[:, 1] = np.sin(
+ (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+
+ xs = xtot[:n, :]
+ xt = xtot[1:, :]
+
+ a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
# loss matrix
- M1=ot.dist(xs,xt,metric='euclidean')
- M1/=M1.max()
-
+ M1 = ot.dist(xs, xt, metric='euclidean')
+ M1 /= M1.max()
+
# loss matrix
- M2=ot.dist(xs,xt,metric='sqeuclidean')
- M2/=M2.max()
-
+ M2 = ot.dist(xs, xt, metric='sqeuclidean')
+ M2 /= M2.max()
+
# loss matrix
- Mp=np.sqrt(ot.dist(xs,xt,metric='euclidean'))
- Mp/=Mp.max()
-
+ Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+ Mp /= Mp.max()
+
#%% plot samples
-
- pl.figure(1+3*data)
+
+ pl.figure(1 + 3 * data, figsize=(7, 3))
pl.clf()
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
pl.title('Source and traget distributions')
-
- pl.figure(2+3*data,(15,5))
- pl.subplot(1,3,1)
- pl.imshow(M1,interpolation='nearest')
- pl.title('Eucidean cost')
- pl.subplot(1,3,2)
- pl.imshow(M2,interpolation='nearest')
+
+ pl.figure(2 + 3 * data, figsize=(7, 3))
+
+ pl.subplot(1, 3, 1)
+ pl.imshow(M1, interpolation='nearest')
+ pl.title('Euclidean cost')
+
+ pl.subplot(1, 3, 2)
+ pl.imshow(M2, interpolation='nearest')
pl.title('Squared Euclidean cost')
-
- pl.subplot(1,3,3)
- pl.imshow(Mp,interpolation='nearest')
+
+ pl.subplot(1, 3, 3)
+ pl.imshow(Mp, interpolation='nearest')
pl.title('Sqrt Euclidean cost')
+ pl.tight_layout()
+
#%% EMD
-
- G1=ot.emd(a,b,M1)
- G2=ot.emd(a,b,M2)
- Gp=ot.emd(a,b,Mp)
-
- pl.figure(3+3*data,(15,5))
-
- pl.subplot(1,3,1)
- ot.plot.plot2D_samples_mat(xs,xt,G1,c=[.5,.5,1])
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+ G1 = ot.emd(a, b, M1)
+ G2 = ot.emd(a, b, M2)
+ Gp = ot.emd(a, b, Mp)
+
+ pl.figure(3 + 3 * data, figsize=(7, 3))
+
+ pl.subplot(1, 3, 1)
+ ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
- #pl.legend(loc=0)
+ # pl.legend(loc=0)
pl.title('OT Euclidean')
-
- pl.subplot(1,3,2)
-
- ot.plot.plot2D_samples_mat(xs,xt,G2,c=[.5,.5,1])
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+
+ pl.subplot(1, 3, 2)
+ ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
- #pl.legend(loc=0)
+ # pl.legend(loc=0)
pl.title('OT squared Euclidean')
-
- pl.subplot(1,3,3)
-
- ot.plot.plot2D_samples_mat(xs,xt,Gp,c=[.5,.5,1])
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+
+ pl.subplot(1, 3, 3)
+ ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
- #pl.legend(loc=0)
+ # pl.legend(loc=0)
pl.title('OT sqrt Euclidean')
+ pl.tight_layout()
+
+ pl.show()
-**Total running time of the script:** ( 0 minutes 1.417 seconds)
+**Total running time of the script:** ( 0 minutes 1.906 seconds)