summaryrefslogtreecommitdiff
path: root/examples/plot_OT_1D.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/plot_OT_1D.py')
-rw-r--r--examples/plot_OT_1D.py44
1 files changed, 24 insertions, 20 deletions
diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py
index 6661aa3..0f3a26a 100644
--- a/examples/plot_OT_1D.py
+++ b/examples/plot_OT_1D.py
@@ -4,53 +4,57 @@
1D optimal transport
====================
-@author: rflamary
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
import matplotlib.pylab as pl
import ot
from ot.datasets import get_1D_gauss as gauss
-
#%% parameters
-n=100 # nb bins
+n = 100 # nb bins
# bin positions
-x=np.arange(n,dtype=np.float64)
+x = np.arange(n, dtype=np.float64)
# Gaussian distributions
-a=gauss(n,m=20,s=5) # m= mean, s= std
-b=gauss(n,m=60,s=10)
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
# loss matrix
-M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
-M/=M.max()
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
#%% plot the distributions
-pl.figure(1)
-pl.plot(x,a,'b',label='Source distribution')
-pl.plot(x,b,'r',label='Target distribution')
+pl.figure(1, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
pl.legend()
#%% plot distributions and loss matrix
-pl.figure(2)
-ot.plot.plot1D_mat(a,b,M,'Cost matrix M')
+pl.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
#%% EMD
-G0=ot.emd(a,b,M)
+G0 = ot.emd(a, b, M)
-pl.figure(3)
-ot.plot.plot1D_mat(a,b,G0,'OT matrix G0')
+pl.figure(3, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
#%% Sinkhorn
-lambd=1e-3
-Gs=ot.sinkhorn(a,b,M,lambd,verbose=True)
+lambd = 1e-3
+Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')
-pl.figure(4)
-ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')
+pl.show()