diff options
author | eloitanguy <69361683+eloitanguy@users.noreply.github.com> | 2022-05-06 08:43:21 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-06 08:43:21 +0200 |
commit | ccc076e0fc535b2c734214c0ac1936e9e2cbeb62 (patch) | |
tree | b5a20af6fabcaefa0de4bc27afd9049bd15612f6 /examples | |
parent | eccb1386eea52b94b82456d126bd20cbe3198e05 (diff) |
[WIP] Generalized Wasserstein Barycenters (#372)
* GWB first solver version
* tests + example for gwb (untested) + free_bar doc fix
* improved doc, fixed minor bugs, better example visu
* minor doc + visu fixes
* plot GWB pep8 fix
* fixed partial gromov test reproductibility
* added an animation for the GWB visu
* added PR num
* minor doc fixes + better gwb logo
Diffstat (limited to 'examples')
-rw-r--r-- | examples/backends/plot_sliced_wass_grad_flow_pytorch.py | 6 | ||||
-rw-r--r-- | examples/barycenters/plot_generalized_free_support_barycenter.py | 152 |
2 files changed, 155 insertions, 3 deletions
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py index cf5d64d..59e0042 100644 --- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py +++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py @@ -103,7 +103,7 @@ ax = pl.axis() # %% # Animate trajectories of the gradient flow along iteration -# ------------------------------------------------------- +# --------------------------------------------------------- pl.figure(3, (8, 4)) @@ -122,7 +122,7 @@ ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, # %% # Compute the Sliced Wasserstein Barycenter -# +# ----------------------------------------- x1_torch = torch.tensor(x1).to(device=device) x3_torch = torch.tensor(x3).to(device=device) xbinit = np.random.randn(500, 2) * 10 + 16 @@ -169,7 +169,7 @@ ax = pl.axis() # %% # Animate trajectories of the barycenter along gradient descent -# ------------------------------------------------------- +# ------------------------------------------------------------- pl.figure(5, (8, 4)) diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py new file mode 100644 index 0000000..9af1953 --- /dev/null +++ b/examples/barycenters/plot_generalized_free_support_barycenter.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- +""" +======================================= +Generalized Wasserstein Barycenter Demo +======================================= + +This example illustrates the computation of Generalized Wasserstein Barycenter +as proposed in [42]. + + +[42] Delon, J., Gozlan, N., and Saint-Dizier, A.. +Generalized Wasserstein barycenters between probability measures living on different subspaces. +arXiv preprint arXiv:2105.09755, 2021. + +""" + +# Author: Eloi Tanguy <eloi.tanguy@polytechnique.edu> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.pylab as pl +import ot +import matplotlib.animation as animation + +######################## +# Generate and plot data +# ---------------------- + +# Input measures +sub_sample_factor = 8 +I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] +I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] +I3 = pl.imread('../../data/heart.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] + +sz = I1.shape[0] +UU, VV = np.meshgrid(np.arange(sz), np.arange(sz)) + +# Input measure locations in their respective 2D spaces +X_list = [np.stack((UU[im == 0], VV[im == 0]), 1) * 1.0 for im in [I1, I2, I3]] + +# Input measure weights +a_list = [ot.unif(x.shape[0]) for x in X_list] + +# Projections 3D -> 2D +P1 = np.array([[1, 0, 0], [0, 1, 0]]) +P2 = np.array([[0, 1, 0], [0, 0, 1]]) +P3 = np.array([[1, 0, 0], [0, 0, 1]]) +P_list = [P1, P2, P3] + +# Barycenter weights +weights = np.array([1 / 3, 1 / 3, 1 / 3]) + +# Number of barycenter points to compute +n_samples_bary = 150 + +# Send the input measures into 3D space for visualisation +X_visu = [Xi @ Pi for (Xi, Pi) in zip(X_list, P_list)] + +# Plot the input data +fig = plt.figure(figsize=(3, 3)) +axis = fig.add_subplot(1, 1, 1, projection="3d") +for Xi in X_visu: + axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +axis.view_init(azim=45) +axis.set_xticks([]) +axis.set_yticks([]) +axis.set_zticks([]) +plt.show() + +################################# +# Barycenter computation and plot +# ------------------------------- + +Y = ot.lp.generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary) +fig = plt.figure(figsize=(3, 3)) + +axis = fig.add_subplot(1, 1, 1, projection="3d") +for Xi in X_visu: + axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +axis.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +axis.view_init(azim=45) +axis.set_xticks([]) +axis.set_yticks([]) +axis.set_zticks([]) +plt.show() + + +############################# +# Plotting projection matches +# --------------------------- + +fig = plt.figure(figsize=(9, 3)) + +ax = fig.add_subplot(1, 3, 1, projection='3d') +for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +ax.view_init(elev=0, azim=0) +ax.set_xticks([]) +ax.set_yticks([]) +ax.set_zticks([]) + +ax = fig.add_subplot(1, 3, 2, projection='3d') +for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +ax.view_init(elev=0, azim=90) +ax.set_xticks([]) +ax.set_yticks([]) +ax.set_zticks([]) + +ax = fig.add_subplot(1, 3, 3, projection='3d') +for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +ax.view_init(elev=90, azim=0) +ax.set_xticks([]) +ax.set_yticks([]) +ax.set_zticks([]) + +plt.tight_layout() +plt.show() + +############################################## +# Rotation animation +# -------------------------------------------- + +fig = plt.figure(figsize=(7, 7)) +ax = fig.add_subplot(1, 1, 1, projection="3d") + + +def _init(): + for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) + ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) + ax.view_init(elev=0, azim=0) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_zticks([]) + return fig, + + +def _update_plot(i): + ax.view_init(elev=i, azim=4 * i) + return fig, + + +ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=90, interval=50, blit=True, repeat_delay=2000) |