summaryrefslogtreecommitdiff
path: root/examples/demo_OT_1D.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/demo_OT_1D.py')
-rw-r--r--examples/demo_OT_1D.py49
1 files changed, 6 insertions, 43 deletions
diff --git a/examples/demo_OT_1D.py b/examples/demo_OT_1D.py
index 8ba40ba..b17f902 100644
--- a/examples/demo_OT_1D.py
+++ b/examples/demo_OT_1D.py
@@ -7,7 +7,7 @@ Created on Fri Oct 21 09:51:45 2016
import numpy as np
import matplotlib.pylab as pl
-from matplotlib import gridspec
+
import ot
@@ -16,18 +16,12 @@ import ot
n=100 # nb bins
-ma=20 # mean of a
-mb=60 # mean of b
-
-sa=20 # std of a
-sb=60 # std of b
-
# bin positions
x=np.arange(n,dtype=np.float64)
# Gaussian distributions
-a=ot.datasets.get_1D_gauss(n,ma,sa)
-b=ot.datasets.get_1D_gauss(n,mb,sb)
+a=ot.datasets.get_1D_gauss(n,m=20,s=20) # m= mean, s= std
+b=ot.datasets.get_1D_gauss(n,m=60,s=60)
# loss matrix
M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
@@ -44,47 +38,16 @@ pl.legend()
#%% plot distributions and loss matrix
-def plotmat(M,title=''):
- """ Plot a matrix woth the 1D distribution """
- gs = gridspec.GridSpec(3, 3)
-
- ax1=pl.subplot(gs[0,1:])
- pl.plot(x,b,'r',label='Target distribution')
- pl.yticks(())
- pl.title(title)
-
- #pl.axis('off')
-
- ax2=pl.subplot(gs[1:,0])
- pl.plot(a,x,'b',label='Source distribution')
- pl.gca().invert_xaxis()
- pl.gca().invert_yaxis()
- pl.xticks(())
- #pl.ylim((0,n))
- #pl.axis('off')
-
- pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2)
- pl.imshow(M,interpolation='nearest')
-
- pl.xlim((0,n))
-
pl.figure(2)
+ot.plot.otplot1D(a,b,M,'Cost matrix M')
-plotmat(M,'Cost matrix M')
-
-
-
-
-
-#pl.ylim((0,n))
-#pl.axis('off')
#%% EMD
G0=ot.emd(a,b,M)
pl.figure(3)
-plotmat(G0,'OT matrix G0')
+ot.plot.otplot1D(a,b,G0,'OT matrix G0')
#%% Sinkhorn
lambd=1e-3
@@ -92,4 +55,4 @@ lambd=1e-3
Gs=ot.sinkhorn(a,b,M,lambd)
pl.figure(4)
-plotmat(Gs,'OT matrix Sinkhorn')
+ot.plot.otplot1D(a,b,Gs,'OT matrix Sinkhorn')