diff options
Diffstat (limited to 'examples/demo_OT_1D.py')
-rw-r--r-- | examples/demo_OT_1D.py | 96 |
1 files changed, 32 insertions, 64 deletions
diff --git a/examples/demo_OT_1D.py b/examples/demo_OT_1D.py index 865b178..8ba40ba 100644 --- a/examples/demo_OT_1D.py +++ b/examples/demo_OT_1D.py @@ -44,27 +44,38 @@ pl.legend() #%% plot distributions and loss matrix -pl.figure(2) -gs = gridspec.GridSpec(3, 3) - -ax1=pl.subplot(gs[0,1:]) -pl.plot(x,b,'r',label='Target distribution') -pl.yticks(()) +def plotmat(M,title=''): + """ Plot a matrix woth the 1D distribution """ + gs = gridspec.GridSpec(3, 3) + + ax1=pl.subplot(gs[0,1:]) + pl.plot(x,b,'r',label='Target distribution') + pl.yticks(()) + pl.title(title) + + #pl.axis('off') + + ax2=pl.subplot(gs[1:,0]) + pl.plot(a,x,'b',label='Source distribution') + pl.gca().invert_xaxis() + pl.gca().invert_yaxis() + pl.xticks(()) + #pl.ylim((0,n)) + #pl.axis('off') + + pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2) + pl.imshow(M,interpolation='nearest') + + pl.xlim((0,n)) -#pl.axis('off') +pl.figure(2) -ax2=pl.subplot(gs[1:,0]) -pl.plot(a,x,'b',label='Source distribution') -pl.gca().invert_xaxis() -pl.gca().invert_yaxis() -pl.xticks(()) -#pl.ylim((0,n)) -#pl.axis('off') +plotmat(M,'Cost matrix M') -pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2) -pl.imshow(M,interpolation='nearest') -pl.xlim((0,n)) + + + #pl.ylim((0,n)) #pl.axis('off') @@ -72,56 +83,13 @@ pl.xlim((0,n)) G0=ot.emd(a,b,M) -#%% plot EMD optimal tranport matrix pl.figure(3) -gs = gridspec.GridSpec(3, 3) - -ax1=pl.subplot(gs[0,1:]) -pl.plot(x,b,'r',label='Target distribution') -pl.yticks(()) - -#pl.axis('off') - -ax2=pl.subplot(gs[1:,0]) -pl.plot(a,x,'b',label='Source distribution') -pl.gca().invert_xaxis() -pl.gca().invert_yaxis() -pl.xticks(()) -#pl.ylim((0,n)) -#pl.axis('off') - -pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2) -pl.imshow(G0,interpolation='nearest') - -pl.xlim((0,n)) -#pl.ylim((0,n)) -#pl.axis('off') +plotmat(G0,'OT matrix G0') #%% Sinkhorn -lambd=1e3 +lambd=1e-3 Gs=ot.sinkhorn(a,b,M,lambd) - -#%% plot Sikhorn optimal tranport matrix -pl.figure(3) -gs = gridspec.GridSpec(3, 3) - -ax1=pl.subplot(gs[0,1:]) -pl.plot(x,b,'r',label='Target distribution') -pl.yticks(()) - -#pl.axis('off') - -ax2=pl.subplot(gs[1:,0]) -pl.plot(a,x,'b',label='Source distribution') -pl.gca().invert_xaxis() -pl.gca().invert_yaxis() -pl.xticks(()) -#pl.ylim((0,n)) -#pl.axis('off') - -pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2) -pl.imshow(Gs,interpolation='nearest') - -pl.xlim((0,n))
\ No newline at end of file +pl.figure(4) +plotmat(Gs,'OT matrix Sinkhorn') |