summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-21 10:51:27 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-21 10:51:27 +0200
commit581c6de782dca279edd97778cc474e7597788c0f (patch)
tree760161e1c7812d8caf77bf8acc543453c6213e39 /examples
parent2109443f5bea396114d1f9e0563ba5c396378c57 (diff)
demo+sinkhorn
Diffstat (limited to 'examples')
-rw-r--r--examples/demo_OT_1D.py102
1 files changed, 102 insertions, 0 deletions
diff --git a/examples/demo_OT_1D.py b/examples/demo_OT_1D.py
new file mode 100644
index 0000000..29f2074
--- /dev/null
+++ b/examples/demo_OT_1D.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Oct 21 09:51:45 2016
+
+@author: rflamary
+"""
+
+import numpy as np
+import matplotlib.pylab as pl
+from matplotlib import gridspec
+import ot
+
+
+
+#%% parameters
+
+n=100 # nb bins
+
+ma=20 # mean of a
+mb=60 # mean of b
+
+sa=20 # std of a
+sb=60 # std of b
+
+# bin positions
+x=np.arange(n,dtype=np.float64)
+
+# Gaussian distributions
+a=np.exp(-(x-ma)**2/(2*sa^2))
+b=np.exp(-(x-mb)**2/(2*sb^2))
+
+# normalization
+a/=a.sum()
+b/=b.sum()
+
+# loss matrix
+M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
+M/=M.max()
+
+#%% plot the distributions
+
+pl.figure(1)
+
+pl.plot(x,a,'b',label='Source distribution')
+pl.plot(x,b,'r',label='Target distribution')
+
+pl.legend()
+
+#%% plot distributions and loss matrix
+
+pl.figure(2)
+gs = gridspec.GridSpec(3, 3)
+
+ax1=pl.subplot(gs[0,1:])
+pl.plot(x,b,'r',label='Target distribution')
+pl.yticks(())
+
+#pl.axis('off')
+
+ax2=pl.subplot(gs[1:,0])
+pl.plot(a,x,'b',label='Source distribution')
+pl.gca().invert_xaxis()
+pl.gca().invert_yaxis()
+pl.xticks(())
+#pl.ylim((0,n))
+#pl.axis('off')
+
+pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2)
+pl.imshow(M,interpolation='nearest')
+
+pl.xlim((0,n))
+#pl.ylim((0,n))
+#pl.axis('off')
+
+#%% EMD
+
+G0=ot.emd(a,b,M)
+
+#%% plot EMD optimal tranport matrix
+pl.figure(3)
+gs = gridspec.GridSpec(3, 3)
+
+ax1=pl.subplot(gs[0,1:])
+pl.plot(x,b,'r',label='Target distribution')
+pl.yticks(())
+
+#pl.axis('off')
+
+ax2=pl.subplot(gs[1:,0])
+pl.plot(a,x,'b',label='Source distribution')
+pl.gca().invert_xaxis()
+pl.gca().invert_yaxis()
+pl.xticks(())
+#pl.ylim((0,n))
+#pl.axis('off')
+
+pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2)
+pl.imshow(G0,interpolation='nearest')
+
+pl.xlim((0,n))
+#pl.ylim((0,n))
+#pl.axis('off') \ No newline at end of file