diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-21 11:19:46 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-21 11:19:46 +0200 |
commit | 872e6db7c0d110069b450cbe7efcc186c4871428 (patch) | |
tree | f499a0258cf69f47a211a54447af990ac0afb591 /examples | |
parent | 581c6de782dca279edd97778cc474e7597788c0f (diff) |
demo with sinkhorn
Diffstat (limited to 'examples')
-rw-r--r-- | examples/demo_OT_1D.py | 39 |
1 files changed, 32 insertions, 7 deletions
diff --git a/examples/demo_OT_1D.py b/examples/demo_OT_1D.py index 29f2074..865b178 100644 --- a/examples/demo_OT_1D.py +++ b/examples/demo_OT_1D.py @@ -26,12 +26,8 @@ sb=60 # std of b x=np.arange(n,dtype=np.float64) # Gaussian distributions -a=np.exp(-(x-ma)**2/(2*sa^2)) -b=np.exp(-(x-mb)**2/(2*sb^2)) - -# normalization -a/=a.sum() -b/=b.sum() +a=ot.datasets.get_1D_gauss(n,ma,sa) +b=ot.datasets.get_1D_gauss(n,mb,sb) # loss matrix M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) @@ -99,4 +95,33 @@ pl.imshow(G0,interpolation='nearest') pl.xlim((0,n)) #pl.ylim((0,n)) -#pl.axis('off')
\ No newline at end of file +#pl.axis('off') + +#%% Sinkhorn +lambd=1e3 + +Gs=ot.sinkhorn(a,b,M,lambd) + + +#%% plot Sikhorn optimal tranport matrix +pl.figure(3) +gs = gridspec.GridSpec(3, 3) + +ax1=pl.subplot(gs[0,1:]) +pl.plot(x,b,'r',label='Target distribution') +pl.yticks(()) + +#pl.axis('off') + +ax2=pl.subplot(gs[1:,0]) +pl.plot(a,x,'b',label='Source distribution') +pl.gca().invert_xaxis() +pl.gca().invert_yaxis() +pl.xticks(()) +#pl.ylim((0,n)) +#pl.axis('off') + +pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2) +pl.imshow(Gs,interpolation='nearest') + +pl.xlim((0,n))
\ No newline at end of file |