summaryrefslogtreecommitdiff
path: root/examples/plot_OT_1D.py
diff options
context:
space:
mode:
authorAlexandre Gramfort <alexandre.gramfort@m4x.org>2016-12-02 12:49:38 +0100
committerAlexandre Gramfort <alexandre.gramfort@m4x.org>2016-12-02 12:57:25 +0100
commitf439f777084690ecbf54bcd8d67dadc883fffa31 (patch)
tree56c8a160fb6edcdaca8ca6ce6de1949b9bc33b77 /examples/plot_OT_1D.py
parent8dbfd3edae649f5f3e87be4a3ce446c59729b2f7 (diff)
first attempt to support sphinx-gallery
Diffstat (limited to 'examples/plot_OT_1D.py')
-rw-r--r--examples/plot_OT_1D.py56
1 files changed, 56 insertions, 0 deletions
diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py
new file mode 100644
index 0000000..a8bbbd6
--- /dev/null
+++ b/examples/plot_OT_1D.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+"""
+=============================
+Demo for 1D optimal transport
+=============================
+
+@author: rflamary
+"""
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+from ot.datasets import get_1D_gauss as gauss
+
+
+#%% parameters
+
+n=100 # nb bins
+
+# bin positions
+x=np.arange(n,dtype=np.float64)
+
+# Gaussian distributions
+a=gauss(n,m=20,s=5) # m= mean, s= std
+b=gauss(n,m=60,s=10)
+
+# loss matrix
+M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
+M/=M.max()
+
+#%% plot the distributions
+
+pl.figure(1)
+pl.plot(x,a,'b',label='Source distribution')
+pl.plot(x,b,'r',label='Target distribution')
+pl.legend()
+
+#%% plot distributions and loss matrix
+
+pl.figure(2)
+ot.plot.plot1D_mat(a,b,M,'Cost matrix M')
+
+#%% EMD
+
+G0=ot.emd(a,b,M)
+
+pl.figure(3)
+ot.plot.plot1D_mat(a,b,G0,'OT matrix G0')
+
+#%% Sinkhorn
+
+lambd=1e-3
+Gs=ot.sinkhorn(a,b,M,lambd)
+
+pl.figure(4)
+ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')