From 872e6db7c0d110069b450cbe7efcc186c4871428 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Fri, 21 Oct 2016 11:19:46 +0200 Subject: demo with sinkhorn --- examples/demo_OT_1D.py | 39 ++++++++++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 7 deletions(-) (limited to 'examples/demo_OT_1D.py') diff --git a/examples/demo_OT_1D.py b/examples/demo_OT_1D.py index 29f2074..865b178 100644 --- a/examples/demo_OT_1D.py +++ b/examples/demo_OT_1D.py @@ -26,12 +26,8 @@ sb=60 # std of b x=np.arange(n,dtype=np.float64) # Gaussian distributions -a=np.exp(-(x-ma)**2/(2*sa^2)) -b=np.exp(-(x-mb)**2/(2*sb^2)) - -# normalization -a/=a.sum() -b/=b.sum() +a=ot.datasets.get_1D_gauss(n,ma,sa) +b=ot.datasets.get_1D_gauss(n,mb,sb) # loss matrix M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) @@ -99,4 +95,33 @@ pl.imshow(G0,interpolation='nearest') pl.xlim((0,n)) #pl.ylim((0,n)) -#pl.axis('off') \ No newline at end of file +#pl.axis('off') + +#%% Sinkhorn +lambd=1e3 + +Gs=ot.sinkhorn(a,b,M,lambd) + + +#%% plot Sikhorn optimal tranport matrix +pl.figure(3) +gs = gridspec.GridSpec(3, 3) + +ax1=pl.subplot(gs[0,1:]) +pl.plot(x,b,'r',label='Target distribution') +pl.yticks(()) + +#pl.axis('off') + +ax2=pl.subplot(gs[1:,0]) +pl.plot(a,x,'b',label='Source distribution') +pl.gca().invert_xaxis() +pl.gca().invert_yaxis() +pl.xticks(()) +#pl.ylim((0,n)) +#pl.axis('off') + +pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2) +pl.imshow(Gs,interpolation='nearest') + +pl.xlim((0,n)) \ No newline at end of file -- cgit v1.2.3