summaryrefslogtreecommitdiff
path: root/examples/barycenters/plot_free_support_barycenter.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/barycenters/plot_free_support_barycenter.py')
-rw-r--r--examples/barycenters/plot_free_support_barycenter.py28
1 files changed, 25 insertions, 3 deletions
diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py
index 226dfeb..f4a13dd 100644
--- a/examples/barycenters/plot_free_support_barycenter.py
+++ b/examples/barycenters/plot_free_support_barycenter.py
@@ -4,13 +4,14 @@
2D free support Wasserstein barycenters of distributions
========================================================
-Illustration of 2D Wasserstein barycenters if distributions are weighted
+Illustration of 2D Wasserstein and Sinkhorn barycenters if distributions are weighted
sum of diracs.
"""
# Authors: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
# RĂ©mi Flamary <remi.flamary@polytechnique.edu>
+# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
#
# License: MIT License
@@ -48,7 +49,7 @@ pl.title('Distributions')
# %%
-# Compute free support barycenter
+# Compute free support Wasserstein barycenter
# -------------------------------
k = 200 # number of Diracs of the barycenter
@@ -58,7 +59,28 @@ b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, on
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
# %%
-# Plot the barycenter
+# Plot the Wasserstein barycenter
+# ---------
+
+pl.figure(2, (8, 3))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter')
+pl.title('Data measures and their barycenter')
+pl.legend(loc="lower right")
+pl.show()
+
+# %%
+# Compute free support Sinkhorn barycenter
+
+k = 200 # number of Diracs of the barycenter
+X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
+b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
+
+X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, 20, b, numItermax=15)
+
+# %%
+# Plot the Wasserstein barycenter
# ---------
pl.figure(2, (8, 3))