From 3f23fa1a950ffde4a4224a6343a504a0c5b7851b Mon Sep 17 00:00:00 2001 From: vivienseguy Date: Thu, 5 Jul 2018 17:35:27 +0900 Subject: free support barycenter --- examples/plot_free_support_barycenter.py | 66 ++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 examples/plot_free_support_barycenter.py (limited to 'examples/plot_free_support_barycenter.py') diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py new file mode 100644 index 0000000..a733745 --- /dev/null +++ b/examples/plot_free_support_barycenter.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +2D Wasserstein barycenters between empirical distributions +==================================================== + +Illustration of 2D Wasserstein barycenters between discributions that are weighted +sum of diracs. + +""" + +# Author: Vivien Seguy +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot.plot + + +############################################################################## +# Generate data +# ------------- +#%% parameters and data generation +N = 4 +d = 2 +measures_locations = [] +measures_weights = [] + +for i in range(N): + + n = np.rand.int(low=1, high=20) # nb samples + + mu = np.random.normal(0., 1., (d,)) + cov = np.random.normal(0., 1., (d,d)) + + xs = ot.datasets.make_2D_samples_gauss(n, mu, cov) + b = np.random.uniform(0., 1., n) + b = b/np.sum(b) + + measures_locations.append(xs) + measures_weights.append(b) + +k = 10 +X_init = np.random.normal(0., 1., (k,d)) +b_init = np.ones((k,)) / k + + +############################################################################## +# Compute free support barycenter +# ------------- +X = ot.lp.barycenter(measures_locations, measures_weights, X_init, b_init) + + +############################################################################## +# Plot data +# --------- + +#%% plot samples + +pl.figure(1) +for (xs, b) in zip(measures_locations, measures_weights): + pl.scatter(xs[:, 0], xs[:, 1], s=b, c=np.tile(np.rand(0. ,255., size=(3,)), (1,b.size(0))) , label='Data measures') +pl.scatter(xs[:, 0], xs[:, 1], s=b, c='black' , label='2-Wasserstein barycenter') +pl.legend(loc=0) +pl.title('Data measures and their barycenter') -- cgit v1.2.3