diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-24 10:28:21 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-24 10:28:21 +0200 |
commit | b9ae39adae1d1a16f6bfb79e73c0d67d3157e1de (patch) | |
tree | eee92564b9c7b02eca07b7a5a6e741245e70e5e5 /examples/demo_OT_1D.py | |
parent | 89999e07871f82bd1738750b5bbcd97fae9ea4b2 (diff) |
update demo and add plot module
Diffstat (limited to 'examples/demo_OT_1D.py')
-rw-r--r-- | examples/demo_OT_1D.py | 49 |
1 files changed, 6 insertions, 43 deletions
diff --git a/examples/demo_OT_1D.py b/examples/demo_OT_1D.py index 8ba40ba..b17f902 100644 --- a/examples/demo_OT_1D.py +++ b/examples/demo_OT_1D.py @@ -7,7 +7,7 @@ Created on Fri Oct 21 09:51:45 2016 import numpy as np import matplotlib.pylab as pl -from matplotlib import gridspec + import ot @@ -16,18 +16,12 @@ import ot n=100 # nb bins -ma=20 # mean of a -mb=60 # mean of b - -sa=20 # std of a -sb=60 # std of b - # bin positions x=np.arange(n,dtype=np.float64) # Gaussian distributions -a=ot.datasets.get_1D_gauss(n,ma,sa) -b=ot.datasets.get_1D_gauss(n,mb,sb) +a=ot.datasets.get_1D_gauss(n,m=20,s=20) # m= mean, s= std +b=ot.datasets.get_1D_gauss(n,m=60,s=60) # loss matrix M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) @@ -44,47 +38,16 @@ pl.legend() #%% plot distributions and loss matrix -def plotmat(M,title=''): - """ Plot a matrix woth the 1D distribution """ - gs = gridspec.GridSpec(3, 3) - - ax1=pl.subplot(gs[0,1:]) - pl.plot(x,b,'r',label='Target distribution') - pl.yticks(()) - pl.title(title) - - #pl.axis('off') - - ax2=pl.subplot(gs[1:,0]) - pl.plot(a,x,'b',label='Source distribution') - pl.gca().invert_xaxis() - pl.gca().invert_yaxis() - pl.xticks(()) - #pl.ylim((0,n)) - #pl.axis('off') - - pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2) - pl.imshow(M,interpolation='nearest') - - pl.xlim((0,n)) - pl.figure(2) +ot.plot.otplot1D(a,b,M,'Cost matrix M') -plotmat(M,'Cost matrix M') - - - - - -#pl.ylim((0,n)) -#pl.axis('off') #%% EMD G0=ot.emd(a,b,M) pl.figure(3) -plotmat(G0,'OT matrix G0') +ot.plot.otplot1D(a,b,G0,'OT matrix G0') #%% Sinkhorn lambd=1e-3 @@ -92,4 +55,4 @@ lambd=1e-3 Gs=ot.sinkhorn(a,b,M,lambd) pl.figure(4) -plotmat(Gs,'OT matrix Sinkhorn') +ot.plot.otplot1D(a,b,Gs,'OT matrix Sinkhorn') |