From ccc076e0fc535b2c734214c0ac1936e9e2cbeb62 Mon Sep 17 00:00:00 2001 From: eloitanguy <69361683+eloitanguy@users.noreply.github.com> Date: Fri, 6 May 2022 08:43:21 +0200 Subject: [WIP] Generalized Wasserstein Barycenters (#372) * GWB first solver version * tests + example for gwb (untested) + free_bar doc fix * improved doc, fixed minor bugs, better example visu * minor doc + visu fixes * plot GWB pep8 fix * fixed partial gromov test reproductibility * added an animation for the GWB visu * added PR num * minor doc fixes + better gwb logo --- .../plot_generalized_free_support_barycenter.py | 152 +++++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 examples/barycenters/plot_generalized_free_support_barycenter.py (limited to 'examples/barycenters/plot_generalized_free_support_barycenter.py') diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py new file mode 100644 index 0000000..9af1953 --- /dev/null +++ b/examples/barycenters/plot_generalized_free_support_barycenter.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- +""" +======================================= +Generalized Wasserstein Barycenter Demo +======================================= + +This example illustrates the computation of Generalized Wasserstein Barycenter +as proposed in [42]. + + +[42] Delon, J., Gozlan, N., and Saint-Dizier, A.. +Generalized Wasserstein barycenters between probability measures living on different subspaces. +arXiv preprint arXiv:2105.09755, 2021. + +""" + +# Author: Eloi Tanguy +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.pylab as pl +import ot +import matplotlib.animation as animation + +######################## +# Generate and plot data +# ---------------------- + +# Input measures +sub_sample_factor = 8 +I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] +I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] +I3 = pl.imread('../../data/heart.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] + +sz = I1.shape[0] +UU, VV = np.meshgrid(np.arange(sz), np.arange(sz)) + +# Input measure locations in their respective 2D spaces +X_list = [np.stack((UU[im == 0], VV[im == 0]), 1) * 1.0 for im in [I1, I2, I3]] + +# Input measure weights +a_list = [ot.unif(x.shape[0]) for x in X_list] + +# Projections 3D -> 2D +P1 = np.array([[1, 0, 0], [0, 1, 0]]) +P2 = np.array([[0, 1, 0], [0, 0, 1]]) +P3 = np.array([[1, 0, 0], [0, 0, 1]]) +P_list = [P1, P2, P3] + +# Barycenter weights +weights = np.array([1 / 3, 1 / 3, 1 / 3]) + +# Number of barycenter points to compute +n_samples_bary = 150 + +# Send the input measures into 3D space for visualisation +X_visu = [Xi @ Pi for (Xi, Pi) in zip(X_list, P_list)] + +# Plot the input data +fig = plt.figure(figsize=(3, 3)) +axis = fig.add_subplot(1, 1, 1, projection="3d") +for Xi in X_visu: + axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +axis.view_init(azim=45) +axis.set_xticks([]) +axis.set_yticks([]) +axis.set_zticks([]) +plt.show() + +################################# +# Barycenter computation and plot +# ------------------------------- + +Y = ot.lp.generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary) +fig = plt.figure(figsize=(3, 3)) + +axis = fig.add_subplot(1, 1, 1, projection="3d") +for Xi in X_visu: + axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +axis.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +axis.view_init(azim=45) +axis.set_xticks([]) +axis.set_yticks([]) +axis.set_zticks([]) +plt.show() + + +############################# +# Plotting projection matches +# --------------------------- + +fig = plt.figure(figsize=(9, 3)) + +ax = fig.add_subplot(1, 3, 1, projection='3d') +for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +ax.view_init(elev=0, azim=0) +ax.set_xticks([]) +ax.set_yticks([]) +ax.set_zticks([]) + +ax = fig.add_subplot(1, 3, 2, projection='3d') +for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +ax.view_init(elev=0, azim=90) +ax.set_xticks([]) +ax.set_yticks([]) +ax.set_zticks([]) + +ax = fig.add_subplot(1, 3, 3, projection='3d') +for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +ax.view_init(elev=90, azim=0) +ax.set_xticks([]) +ax.set_yticks([]) +ax.set_zticks([]) + +plt.tight_layout() +plt.show() + +############################################## +# Rotation animation +# -------------------------------------------- + +fig = plt.figure(figsize=(7, 7)) +ax = fig.add_subplot(1, 1, 1, projection="3d") + + +def _init(): + for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) + ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) + ax.view_init(elev=0, azim=0) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_zticks([]) + return fig, + + +def _update_plot(i): + ax.view_init(elev=i, azim=4 * i) + return fig, + + +ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=90, interval=50, blit=True, repeat_delay=2000) -- cgit v1.2.3 From d6bf10d8502b1c69f58f009b16634a110053eca1 Mon Sep 17 00:00:00 2001 From: eloitanguy <69361683+eloitanguy@users.noreply.github.com> Date: Wed, 11 May 2022 08:57:54 +0200 Subject: [WIP] Graphical tweaks for GWB + fixed seed method for the partial gromov test (#376) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * GWB first solver version * tests + example for gwb (untested) + free_bar doc fix * improved doc, fixed minor bugs, better example visu * minor doc + visu fixes * plot GWB pep8 fix * fixed partial gromov test reproductibility * added an animation for the GWB visu * added PR num * minor doc fixes + better gwb logo * GWB graphical tweaks + better seed method for partial gromov test * fixed PR number * refixed seed issue * seed fix fix fix Co-authored-by: RĂ©mi Flamary --- RELEASES.md | 3 +-- .../barycenters/plot_generalized_free_support_barycenter.py | 11 +++++++---- test/test_partial.py | 10 +++++----- 3 files changed, 13 insertions(+), 11 deletions(-) (limited to 'examples/barycenters/plot_generalized_free_support_barycenter.py') diff --git a/RELEASES.md b/RELEASES.md index c06721f..76385d6 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -4,14 +4,13 @@ #### New features -- Added Generalized Wasserstein Barycenter solver + example (PR #372) +- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376) #### Closed issues - Fixed an issue where we could not ask TorchBackend to place a random tensor on GPU (Issue #371, PR #373) - ## 0.8.2 This releases introduces several new notable features. The less important diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py index 9af1953..e685ec7 100644 --- a/examples/barycenters/plot_generalized_free_support_barycenter.py +++ b/examples/barycenters/plot_generalized_free_support_barycenter.py @@ -33,8 +33,8 @@ import matplotlib.animation as animation # Input measures sub_sample_factor = 8 I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] -I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] -I3 = pl.imread('../../data/heart.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] +I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2] +I3 = pl.imread('../../data/heart.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2] sz = I1.shape[0] UU, VV = np.meshgrid(np.arange(sz), np.arange(sz)) @@ -145,8 +145,11 @@ def _init(): def _update_plot(i): - ax.view_init(elev=i, azim=4 * i) + if i < 45: + ax.view_init(elev=0, azim=4 * i) + else: + ax.view_init(elev=i - 45, azim=4 * i) return fig, -ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=90, interval=50, blit=True, repeat_delay=2000) +ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=136, interval=50, blit=True, repeat_delay=2000) diff --git a/test/test_partial.py b/test/test_partial.py index e07377b..33fc259 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -137,7 +137,7 @@ def test_partial_wasserstein(): def test_partial_gromov_wasserstein(): - np.random.seed(42) + rng = np.random.RandomState(seed=42) n_samples = 20 # nb samples n_noise = 10 # nb of samples (noise) @@ -150,11 +150,11 @@ def test_partial_gromov_wasserstein(): mu_t = np.array([0, 0, 0]) cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s) - xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, rng) + xs = np.concatenate((xs, ((rng.rand(n_noise, 2) + 1) * 4)), axis=0) P = sp.linalg.sqrtm(cov_t) - xt = np.random.randn(n_samples, 3).dot(P) + mu_t - xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0) + xt = rng.randn(n_samples, 3).dot(P) + mu_t + xt = np.concatenate((xt, ((rng.rand(n_noise, 3) + 1) * 10)), axis=0) xt2 = xs[::-1].copy() C1 = ot.dist(xs, xs) -- cgit v1.2.3