summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-24 16:32:22 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-24 16:32:22 +0200
commite009eac25a50e1fb26294a739f9c778e3a693c85 (patch)
treea1af36adb84140f5b707ba5ecf2ead5a0be99607 /examples
parent142f51c080572dd2bfe2b2e9434df648fd7ab018 (diff)
add demo barycenter
Diffstat (limited to 'examples')
-rw-r--r--examples/demo_barycenter_1D.py60
1 files changed, 60 insertions, 0 deletions
diff --git a/examples/demo_barycenter_1D.py b/examples/demo_barycenter_1D.py
new file mode 100644
index 0000000..200444b
--- /dev/null
+++ b/examples/demo_barycenter_1D.py
@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Oct 21 09:51:45 2016
+
+@author: rflamary
+"""
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+
+
+#%% parameters
+
+n=100 # nb bins
+
+# bin positions
+x=np.arange(n,dtype=np.float64)
+
+# Gaussian distributions
+a1=ot.datasets.get_1D_gauss(n,m=20,s=20) # m= mean, s= std
+a2=ot.datasets.get_1D_gauss(n,m=60,s=60)
+
+A=np.vstack((a1,a2)).T
+nbd=A.shape[1]
+
+# loss matrix
+M=ot.utils.dist0(n)
+M/=M.max()
+
+#%% plot the distributions
+
+pl.figure(1)
+for i in range(nbd):
+ pl.plot(x,A[:,i])
+pl.title('Distributions')
+
+#%% barucenter computation
+
+# l2bary
+bary_l2=A.mean(1)
+
+# wasserstein
+reg=1e-2
+log=dict()
+bary_wass=ot.bregman.barycenter(A,M,reg,log=log)
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2,1,1)
+for i in range(nbd):
+ pl.plot(x,A[:,i])
+pl.title('Distributions')
+
+pl.subplot(2,1,2)
+pl.plot(x,bary_l2,'r',label='l2')
+pl.plot(x,bary_wass,'g',label='Wasserstein')
+pl.legend()
+pl.title('Barycenters')