diff options
author | Gard Spreemann <gspr@nonempty.org> | 2023-06-14 16:52:13 +0200 |
---|---|---|
committer | Gard Spreemann <gspr@nonempty.org> | 2023-06-14 16:52:13 +0200 |
commit | 2b51c7bfcf54d7e17ac7c2514f54408543cbe126 (patch) | |
tree | baf00cc603ceabad00626259eec898e6747d016c /examples/barycenters/plot_free_support_sinkhorn_barycenter.py | |
parent | a49f648b0b07737f7ef315fb83d8f78871780281 (diff) | |
parent | 96788a3fe5601e4c3f49b592aa0d9c034247862e (diff) |
Merge branch 'dfsg/latest'
Diffstat (limited to 'examples/barycenters/plot_free_support_sinkhorn_barycenter.py')
-rw-r--r-- | examples/barycenters/plot_free_support_sinkhorn_barycenter.py | 151 |
1 files changed, 151 insertions, 0 deletions
diff --git a/examples/barycenters/plot_free_support_sinkhorn_barycenter.py b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py new file mode 100644 index 0000000..ebe1f3b --- /dev/null +++ b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +""" +======================================================== +2D free support Sinkhorn barycenters of distributions +======================================================== + +Illustration of Sinkhorn barycenter calculation between empirical distributions understood as point clouds + +""" + +# Authors: Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr> +# +# License: MIT License + +import numpy as np +import matplotlib.pyplot as plt +import ot + +# %% +# General Parameters +# ------------------ +reg = 1e-2 # Entropic Regularization +numItermax = 20 # Maximum number of iterations for the Barycenter algorithm +numInnerItermax = 50 # Maximum number of sinkhorn iterations +n_samples = 200 + +# %% +# Generate Data +# ------------- + +X1 = np.random.randn(200, 2) +X2 = 2 * np.concatenate([ + np.concatenate([- np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1), + np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1), + np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1), + np.concatenate([np.linspace(1, -1, 50)[:, None], - np.ones([50, 1])], axis=1), +], axis=0) +X3 = np.random.randn(200, 2) +X3 = 2 * (X3 / np.linalg.norm(X3, axis=1)[:, None]) +X4 = np.random.multivariate_normal(np.array([0, 0]), np.array([[1., 0.5], [0.5, 1.]]), size=200) + +a1, a2, a3, a4 = ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)) + +# %% +# Inspect generated distributions +# ------------------------------- + +fig, axes = plt.subplots(1, 4, figsize=(16, 4)) + +axes[0].scatter(x=X1[:, 0], y=X1[:, 1], c='steelblue', edgecolor='k') +axes[1].scatter(x=X2[:, 0], y=X2[:, 1], c='steelblue', edgecolor='k') +axes[2].scatter(x=X3[:, 0], y=X3[:, 1], c='steelblue', edgecolor='k') +axes[3].scatter(x=X4[:, 0], y=X4[:, 1], c='steelblue', edgecolor='k') + +axes[0].set_xlim([-3, 3]) +axes[0].set_ylim([-3, 3]) +axes[0].set_title('Distribution 1') + +axes[1].set_xlim([-3, 3]) +axes[1].set_ylim([-3, 3]) +axes[1].set_title('Distribution 2') + +axes[2].set_xlim([-3, 3]) +axes[2].set_ylim([-3, 3]) +axes[2].set_title('Distribution 3') + +axes[3].set_xlim([-3, 3]) +axes[3].set_ylim([-3, 3]) +axes[3].set_title('Distribution 4') + +plt.tight_layout() +plt.show() + +# %% +# Interpolating Empirical Distributions +# ------------------------------------- + +fig = plt.figure(figsize=(10, 10)) + +weights = np.array([ + [3 / 3, 0 / 3], + [2 / 3, 1 / 3], + [1 / 3, 2 / 3], + [0 / 3, 3 / 3], +]).astype(np.float32) + +for k in range(4): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X1, X2], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (0, k)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +for k in range(1, 4, 1): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X1, X3], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (k, 0)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +for k in range(1, 4, 1): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X3, X4], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (3, k)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +for k in range(1, 3, 1): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X2, X4], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (k, 3)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +plt.show() |