diff options
Diffstat (limited to 'examples/demo_OT_1D.py')
-rw-r--r-- | examples/demo_OT_1D.py | 39 |
1 files changed, 32 insertions, 7 deletions
diff --git a/examples/demo_OT_1D.py b/examples/demo_OT_1D.py index 29f2074..865b178 100644 --- a/examples/demo_OT_1D.py +++ b/examples/demo_OT_1D.py @@ -26,12 +26,8 @@ sb=60 # std of b x=np.arange(n,dtype=np.float64) # Gaussian distributions -a=np.exp(-(x-ma)**2/(2*sa^2)) -b=np.exp(-(x-mb)**2/(2*sb^2)) - -# normalization -a/=a.sum() -b/=b.sum() +a=ot.datasets.get_1D_gauss(n,ma,sa) +b=ot.datasets.get_1D_gauss(n,mb,sb) # loss matrix M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) @@ -99,4 +95,33 @@ pl.imshow(G0,interpolation='nearest') pl.xlim((0,n)) #pl.ylim((0,n)) -#pl.axis('off')
\ No newline at end of file +#pl.axis('off') + +#%% Sinkhorn +lambd=1e3 + +Gs=ot.sinkhorn(a,b,M,lambd) + + +#%% plot Sikhorn optimal tranport matrix +pl.figure(3) +gs = gridspec.GridSpec(3, 3) + +ax1=pl.subplot(gs[0,1:]) +pl.plot(x,b,'r',label='Target distribution') +pl.yticks(()) + +#pl.axis('off') + +ax2=pl.subplot(gs[1:,0]) +pl.plot(a,x,'b',label='Source distribution') +pl.gca().invert_xaxis() +pl.gca().invert_yaxis() +pl.xticks(()) +#pl.ylim((0,n)) +#pl.axis('off') + +pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2) +pl.imshow(Gs,interpolation='nearest') + +pl.xlim((0,n))
\ No newline at end of file |