summaryrefslogtreecommitdiff
path: root/examples/barycenters/plot_free_support_barycenter.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/barycenters/plot_free_support_barycenter.py')
-rw-r--r--examples/barycenters/plot_free_support_barycenter.py55
1 files changed, 28 insertions, 27 deletions
diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py
index 2d68a39..226dfeb 100644
--- a/examples/barycenters/plot_free_support_barycenter.py
+++ b/examples/barycenters/plot_free_support_barycenter.py
@@ -9,61 +9,62 @@ sum of diracs.
"""
-# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
+# Authors: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
+# RĂ©mi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
+
import numpy as np
import matplotlib.pylab as pl
import ot
-##############################################################################
+# %%
# Generate data
# -------------
-N = 3
+N = 2
d = 2
-measures_locations = []
-measures_weights = []
-
-for i in range(N):
- n_i = np.random.randint(low=1, high=20) # nb samples
+I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2]
+I2 = pl.imread('../../data/duck.png').astype(np.float64)[::4, ::4, 2]
- mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean
+sz = I2.shape[0]
+XX, YY = np.meshgrid(np.arange(sz), np.arange(sz))
- A_i = np.random.rand(d, d)
- cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix
+x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0
+x2 = np.stack((XX[I2 == 0] + 80, -YY[I2 == 0] + 32), 1) * 1.0
+x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0
- x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations
- b_i = np.random.uniform(0., 1., (n_i,))
- b_i = b_i / np.sum(b_i) # Dirac weights
+measures_locations = [x1, x2]
+measures_weights = [ot.unif(x1.shape[0]), ot.unif(x2.shape[0])]
- measures_locations.append(x_i)
- measures_weights.append(b_i)
+pl.figure(1, (12, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+pl.title('Distributions')
-##############################################################################
+# %%
# Compute free support barycenter
# -------------------------------
-k = 10 # number of Diracs of the barycenter
+k = 200 # number of Diracs of the barycenter
X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
-
-##############################################################################
-# Plot data
+# %%
+# Plot the barycenter
# ---------
-pl.figure(1)
-for (x_i, b_i) in zip(measures_locations, measures_weights):
- color = np.random.randint(low=1, high=10 * N)
- pl.scatter(x_i[:, 0], x_i[:, 1], s=b_i * 1000, label='input measure')
-pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter')
+pl.figure(2, (8, 3))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter')
pl.title('Data measures and their barycenter')
-pl.legend(loc=0)
+pl.legend(loc="lower right")
pl.show()