summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-21 11:19:46 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-21 11:19:46 +0200
commit872e6db7c0d110069b450cbe7efcc186c4871428 (patch)
treef499a0258cf69f47a211a54447af990ac0afb591 /examples
parent581c6de782dca279edd97778cc474e7597788c0f (diff)
demo with sinkhorn
Diffstat (limited to 'examples')
-rw-r--r--examples/demo_OT_1D.py39
1 files changed, 32 insertions, 7 deletions
diff --git a/examples/demo_OT_1D.py b/examples/demo_OT_1D.py
index 29f2074..865b178 100644
--- a/examples/demo_OT_1D.py
+++ b/examples/demo_OT_1D.py
@@ -26,12 +26,8 @@ sb=60 # std of b
x=np.arange(n,dtype=np.float64)
# Gaussian distributions
-a=np.exp(-(x-ma)**2/(2*sa^2))
-b=np.exp(-(x-mb)**2/(2*sb^2))
-
-# normalization
-a/=a.sum()
-b/=b.sum()
+a=ot.datasets.get_1D_gauss(n,ma,sa)
+b=ot.datasets.get_1D_gauss(n,mb,sb)
# loss matrix
M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
@@ -99,4 +95,33 @@ pl.imshow(G0,interpolation='nearest')
pl.xlim((0,n))
#pl.ylim((0,n))
-#pl.axis('off') \ No newline at end of file
+#pl.axis('off')
+
+#%% Sinkhorn
+lambd=1e3
+
+Gs=ot.sinkhorn(a,b,M,lambd)
+
+
+#%% plot Sikhorn optimal tranport matrix
+pl.figure(3)
+gs = gridspec.GridSpec(3, 3)
+
+ax1=pl.subplot(gs[0,1:])
+pl.plot(x,b,'r',label='Target distribution')
+pl.yticks(())
+
+#pl.axis('off')
+
+ax2=pl.subplot(gs[1:,0])
+pl.plot(a,x,'b',label='Source distribution')
+pl.gca().invert_xaxis()
+pl.gca().invert_yaxis()
+pl.xticks(())
+#pl.ylim((0,n))
+#pl.axis('off')
+
+pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2)
+pl.imshow(Gs,interpolation='nearest')
+
+pl.xlim((0,n)) \ No newline at end of file