diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-24 10:28:21 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-24 10:28:21 +0200 |
commit | b9ae39adae1d1a16f6bfb79e73c0d67d3157e1de (patch) | |
tree | eee92564b9c7b02eca07b7a5a6e741245e70e5e5 | |
parent | 89999e07871f82bd1738750b5bbcd97fae9ea4b2 (diff) |
update demo and add plot module
-rw-r--r-- | examples/demo_OT_1D.py | 49 | ||||
-rw-r--r-- | ot/__init__.py | 2 | ||||
-rw-r--r-- | ot/plot.py | 40 |
3 files changed, 48 insertions, 43 deletions
diff --git a/examples/demo_OT_1D.py b/examples/demo_OT_1D.py index 8ba40ba..b17f902 100644 --- a/examples/demo_OT_1D.py +++ b/examples/demo_OT_1D.py @@ -7,7 +7,7 @@ Created on Fri Oct 21 09:51:45 2016 import numpy as np import matplotlib.pylab as pl -from matplotlib import gridspec + import ot @@ -16,18 +16,12 @@ import ot n=100 # nb bins -ma=20 # mean of a -mb=60 # mean of b - -sa=20 # std of a -sb=60 # std of b - # bin positions x=np.arange(n,dtype=np.float64) # Gaussian distributions -a=ot.datasets.get_1D_gauss(n,ma,sa) -b=ot.datasets.get_1D_gauss(n,mb,sb) +a=ot.datasets.get_1D_gauss(n,m=20,s=20) # m= mean, s= std +b=ot.datasets.get_1D_gauss(n,m=60,s=60) # loss matrix M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) @@ -44,47 +38,16 @@ pl.legend() #%% plot distributions and loss matrix -def plotmat(M,title=''): - """ Plot a matrix woth the 1D distribution """ - gs = gridspec.GridSpec(3, 3) - - ax1=pl.subplot(gs[0,1:]) - pl.plot(x,b,'r',label='Target distribution') - pl.yticks(()) - pl.title(title) - - #pl.axis('off') - - ax2=pl.subplot(gs[1:,0]) - pl.plot(a,x,'b',label='Source distribution') - pl.gca().invert_xaxis() - pl.gca().invert_yaxis() - pl.xticks(()) - #pl.ylim((0,n)) - #pl.axis('off') - - pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2) - pl.imshow(M,interpolation='nearest') - - pl.xlim((0,n)) - pl.figure(2) +ot.plot.otplot1D(a,b,M,'Cost matrix M') -plotmat(M,'Cost matrix M') - - - - - -#pl.ylim((0,n)) -#pl.axis('off') #%% EMD G0=ot.emd(a,b,M) pl.figure(3) -plotmat(G0,'OT matrix G0') +ot.plot.otplot1D(a,b,G0,'OT matrix G0') #%% Sinkhorn lambd=1e-3 @@ -92,4 +55,4 @@ lambd=1e-3 Gs=ot.sinkhorn(a,b,M,lambd) pl.figure(4) -plotmat(Gs,'OT matrix Sinkhorn') +ot.plot.otplot1D(a,b,Gs,'OT matrix Sinkhorn') diff --git a/ot/__init__.py b/ot/__init__.py index 14c6181..fe55771 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -2,6 +2,8 @@ # utils submodules import utils import datasets +import plot + # Ot functions from emd import emd diff --git a/ot/plot.py b/ot/plot.py new file mode 100644 index 0000000..b8357b1 --- /dev/null +++ b/ot/plot.py @@ -0,0 +1,40 @@ + +import numpy as np +import matplotlib.pylab as pl +from matplotlib import gridspec + + +def otplot1D(a,b,M,title=''): + """ Plot a with the source and target 1D distribution """ + + na=M.shape[0] + nb=M.shape[1] + + gs = gridspec.GridSpec(3, 3) + + + xa=np.arange(na) + xb=np.arange(nb) + + + ax1=pl.subplot(gs[0,1:]) + pl.plot(xb,b,'r',label='Target distribution') + pl.yticks(()) + pl.title(title) + + #pl.axis('off') + + ax2=pl.subplot(gs[1:,0]) + pl.plot(a,xa,'b',label='Source distribution') + pl.gca().invert_xaxis() + pl.gca().invert_yaxis() + pl.xticks(()) + #pl.ylim((0,n)) + #pl.axis('off') + + pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2) + pl.imshow(M,interpolation='nearest') + + pl.xlim((0,nb)) + + |