diff options
author | Eduardo Fernandes Montesuma <edumontesuma@gmail.com> | 2022-07-27 11:16:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-07-27 11:16:14 +0200 |
commit | 818c7ace20da36d8042b0d7ad7a712b27f7afd59 (patch) | |
tree | 58dd4e0c9f990ea0c851712d85748de99ce6b236 /examples | |
parent | 7c2a9523747c90aebfef711fdf34b5bbdb6f2f4d (diff) |
[MRG] Free support Sinkhorn barycenters (#387)
* Adding function for computing Sinkhorn Free Support barycenters
* Adding exampel on Free Support Sinkhorn Barycenter
* Fixing typo on free support sinkhorn barycenter example
* Adding info on new Free Support Barycenter solver
* Removing extra line so that code follows pep8
* Fixing issues with pep8 in example
* Correcting issues with pep8 standards
* Adding tests for free support sinkhorn barycenter
* Adding section on Sinkhorn barycenter to the example
* Changing distributions for the Sinkhorn barycenter example
* Removing file that should not be on the last commit
* Adding PR number to REALEASES.md
* Adding new contributors
* Update CONTRIBUTORS.md
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'examples')
-rw-r--r-- | examples/barycenters/plot_free_support_barycenter.py | 28 | ||||
-rw-r--r-- | examples/barycenters/plot_free_support_sinkhorn_barycenter.py | 151 |
2 files changed, 176 insertions, 3 deletions
diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py index 226dfeb..f4a13dd 100644 --- a/examples/barycenters/plot_free_support_barycenter.py +++ b/examples/barycenters/plot_free_support_barycenter.py @@ -4,13 +4,14 @@ 2D free support Wasserstein barycenters of distributions ======================================================== -Illustration of 2D Wasserstein barycenters if distributions are weighted +Illustration of 2D Wasserstein and Sinkhorn barycenters if distributions are weighted sum of diracs. """ # Authors: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp> # Rémi Flamary <remi.flamary@polytechnique.edu> +# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr> # # License: MIT License @@ -48,7 +49,7 @@ pl.title('Distributions') # %% -# Compute free support barycenter +# Compute free support Wasserstein barycenter # ------------------------------- k = 200 # number of Diracs of the barycenter @@ -58,7 +59,28 @@ b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, on X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b) # %% -# Plot the barycenter +# Plot the Wasserstein barycenter +# --------- + +pl.figure(2, (8, 3)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) +pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter') +pl.title('Data measures and their barycenter') +pl.legend(loc="lower right") +pl.show() + +# %% +# Compute free support Sinkhorn barycenter + +k = 200 # number of Diracs of the barycenter +X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations +b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized) + +X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, 20, b, numItermax=15) + +# %% +# Plot the Wasserstein barycenter # --------- pl.figure(2, (8, 3)) diff --git a/examples/barycenters/plot_free_support_sinkhorn_barycenter.py b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py new file mode 100644 index 0000000..ebe1f3b --- /dev/null +++ b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +""" +======================================================== +2D free support Sinkhorn barycenters of distributions +======================================================== + +Illustration of Sinkhorn barycenter calculation between empirical distributions understood as point clouds + +""" + +# Authors: Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr> +# +# License: MIT License + +import numpy as np +import matplotlib.pyplot as plt +import ot + +# %% +# General Parameters +# ------------------ +reg = 1e-2 # Entropic Regularization +numItermax = 20 # Maximum number of iterations for the Barycenter algorithm +numInnerItermax = 50 # Maximum number of sinkhorn iterations +n_samples = 200 + +# %% +# Generate Data +# ------------- + +X1 = np.random.randn(200, 2) +X2 = 2 * np.concatenate([ + np.concatenate([- np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1), + np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1), + np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1), + np.concatenate([np.linspace(1, -1, 50)[:, None], - np.ones([50, 1])], axis=1), +], axis=0) +X3 = np.random.randn(200, 2) +X3 = 2 * (X3 / np.linalg.norm(X3, axis=1)[:, None]) +X4 = np.random.multivariate_normal(np.array([0, 0]), np.array([[1., 0.5], [0.5, 1.]]), size=200) + +a1, a2, a3, a4 = ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)) + +# %% +# Inspect generated distributions +# ------------------------------- + +fig, axes = plt.subplots(1, 4, figsize=(16, 4)) + +axes[0].scatter(x=X1[:, 0], y=X1[:, 1], c='steelblue', edgecolor='k') +axes[1].scatter(x=X2[:, 0], y=X2[:, 1], c='steelblue', edgecolor='k') +axes[2].scatter(x=X3[:, 0], y=X3[:, 1], c='steelblue', edgecolor='k') +axes[3].scatter(x=X4[:, 0], y=X4[:, 1], c='steelblue', edgecolor='k') + +axes[0].set_xlim([-3, 3]) +axes[0].set_ylim([-3, 3]) +axes[0].set_title('Distribution 1') + +axes[1].set_xlim([-3, 3]) +axes[1].set_ylim([-3, 3]) +axes[1].set_title('Distribution 2') + +axes[2].set_xlim([-3, 3]) +axes[2].set_ylim([-3, 3]) +axes[2].set_title('Distribution 3') + +axes[3].set_xlim([-3, 3]) +axes[3].set_ylim([-3, 3]) +axes[3].set_title('Distribution 4') + +plt.tight_layout() +plt.show() + +# %% +# Interpolating Empirical Distributions +# ------------------------------------- + +fig = plt.figure(figsize=(10, 10)) + +weights = np.array([ + [3 / 3, 0 / 3], + [2 / 3, 1 / 3], + [1 / 3, 2 / 3], + [0 / 3, 3 / 3], +]).astype(np.float32) + +for k in range(4): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X1, X2], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (0, k)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +for k in range(1, 4, 1): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X1, X3], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (k, 0)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +for k in range(1, 4, 1): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X3, X4], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (3, k)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +for k in range(1, 3, 1): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X2, X4], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (k, 3)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +plt.show() |