summaryrefslogtreecommitdiff
path: root/examples/demo_OT_1D.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-21 11:26:51 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-21 11:26:51 +0200
commit7532b24ca387ff17a63ad1ca2d8099856e38a442 (patch)
treea4d6e6a28e079c1fe0bcf82f7a5389b20870dc17 /examples/demo_OT_1D.py
parent872e6db7c0d110069b450cbe7efcc186c4871428 (diff)
demo with sinkhorn
Diffstat (limited to 'examples/demo_OT_1D.py')
-rw-r--r--examples/demo_OT_1D.py96
1 files changed, 32 insertions, 64 deletions
diff --git a/examples/demo_OT_1D.py b/examples/demo_OT_1D.py
index 865b178..8ba40ba 100644
--- a/examples/demo_OT_1D.py
+++ b/examples/demo_OT_1D.py
@@ -44,27 +44,38 @@ pl.legend()
#%% plot distributions and loss matrix
-pl.figure(2)
-gs = gridspec.GridSpec(3, 3)
-
-ax1=pl.subplot(gs[0,1:])
-pl.plot(x,b,'r',label='Target distribution')
-pl.yticks(())
+def plotmat(M,title=''):
+ """ Plot a matrix woth the 1D distribution """
+ gs = gridspec.GridSpec(3, 3)
+
+ ax1=pl.subplot(gs[0,1:])
+ pl.plot(x,b,'r',label='Target distribution')
+ pl.yticks(())
+ pl.title(title)
+
+ #pl.axis('off')
+
+ ax2=pl.subplot(gs[1:,0])
+ pl.plot(a,x,'b',label='Source distribution')
+ pl.gca().invert_xaxis()
+ pl.gca().invert_yaxis()
+ pl.xticks(())
+ #pl.ylim((0,n))
+ #pl.axis('off')
+
+ pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2)
+ pl.imshow(M,interpolation='nearest')
+
+ pl.xlim((0,n))
-#pl.axis('off')
+pl.figure(2)
-ax2=pl.subplot(gs[1:,0])
-pl.plot(a,x,'b',label='Source distribution')
-pl.gca().invert_xaxis()
-pl.gca().invert_yaxis()
-pl.xticks(())
-#pl.ylim((0,n))
-#pl.axis('off')
+plotmat(M,'Cost matrix M')
-pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2)
-pl.imshow(M,interpolation='nearest')
-pl.xlim((0,n))
+
+
+
#pl.ylim((0,n))
#pl.axis('off')
@@ -72,56 +83,13 @@ pl.xlim((0,n))
G0=ot.emd(a,b,M)
-#%% plot EMD optimal tranport matrix
pl.figure(3)
-gs = gridspec.GridSpec(3, 3)
-
-ax1=pl.subplot(gs[0,1:])
-pl.plot(x,b,'r',label='Target distribution')
-pl.yticks(())
-
-#pl.axis('off')
-
-ax2=pl.subplot(gs[1:,0])
-pl.plot(a,x,'b',label='Source distribution')
-pl.gca().invert_xaxis()
-pl.gca().invert_yaxis()
-pl.xticks(())
-#pl.ylim((0,n))
-#pl.axis('off')
-
-pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2)
-pl.imshow(G0,interpolation='nearest')
-
-pl.xlim((0,n))
-#pl.ylim((0,n))
-#pl.axis('off')
+plotmat(G0,'OT matrix G0')
#%% Sinkhorn
-lambd=1e3
+lambd=1e-3
Gs=ot.sinkhorn(a,b,M,lambd)
-
-#%% plot Sikhorn optimal tranport matrix
-pl.figure(3)
-gs = gridspec.GridSpec(3, 3)
-
-ax1=pl.subplot(gs[0,1:])
-pl.plot(x,b,'r',label='Target distribution')
-pl.yticks(())
-
-#pl.axis('off')
-
-ax2=pl.subplot(gs[1:,0])
-pl.plot(a,x,'b',label='Source distribution')
-pl.gca().invert_xaxis()
-pl.gca().invert_yaxis()
-pl.xticks(())
-#pl.ylim((0,n))
-#pl.axis('off')
-
-pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2)
-pl.imshow(Gs,interpolation='nearest')
-
-pl.xlim((0,n)) \ No newline at end of file
+pl.figure(4)
+plotmat(Gs,'OT matrix Sinkhorn')