diff options
Diffstat (limited to 'examples')
17 files changed, 1344 insertions, 42 deletions
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py index cf5d64d..f00de50 100644 --- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py +++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py @@ -74,7 +74,7 @@ x_all = np.zeros((nb_iter_max, x1.shape[0], 2)) loss_iter = [] # generator for random permutations -gen = torch.Generator() +gen = torch.Generator(device=device) gen.manual_seed(42) for i in range(nb_iter_max): @@ -103,7 +103,7 @@ ax = pl.axis() # %% # Animate trajectories of the gradient flow along iteration -# ------------------------------------------------------- +# --------------------------------------------------------- pl.figure(3, (8, 4)) @@ -122,7 +122,7 @@ ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, # %% # Compute the Sliced Wasserstein Barycenter -# +# ----------------------------------------- x1_torch = torch.tensor(x1).to(device=device) x3_torch = torch.tensor(x3).to(device=device) xbinit = np.random.randn(500, 2) * 10 + 16 @@ -136,7 +136,7 @@ x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2)) loss_iter = [] # generator for random permutations -gen = torch.Generator() +gen = torch.Generator(device=device) gen.manual_seed(42) alpha = 0.5 @@ -169,7 +169,7 @@ ax = pl.axis() # %% # Animate trajectories of the barycenter along gradient descent -# ------------------------------------------------------- +# ------------------------------------------------------------- pl.figure(5, (8, 4)) diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py new file mode 100644 index 0000000..7ccc2af --- /dev/null +++ b/examples/backends/plot_ssw_unif_torch.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +r""" +================================================ +Spherical Sliced-Wasserstein Embedding on Sphere +================================================ + +Here, we aim at transforming samples into a uniform +distribution on the sphere by minimizing SSW: + +.. math:: + \min_{x} SSW_2(\nu, \frac{1}{n}\sum_{i=1}^n \delta_{x_i}) + +where :math:`\nu=\mathrm{Unif}(S^1)`. + +""" + +# Author: Clément Bonet <clement.bonet@univ-ubs.fr> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +import numpy as np +import matplotlib.pyplot as pl +import matplotlib.animation as animation +import torch +import torch.nn.functional as F + +import ot + + +# %% +# Data generation +# --------------- + +torch.manual_seed(1) + +N = 1000 +x0 = torch.rand(N, 3) +x0 = F.normalize(x0, dim=-1) + + +# %% +# Plot data +# --------- + +def plot_sphere(ax): + xlist = np.linspace(-1.0, 1.0, 50) + ylist = np.linspace(-1.0, 1.0, 50) + r = np.linspace(1.0, 1.0, 50) + X, Y = np.meshgrid(xlist, ylist) + + Z = np.sqrt(np.maximum(r**2 - X**2 - Y**2, 0)) + + ax.plot_wireframe(X, Y, Z, color="gray", alpha=.3) + ax.plot_wireframe(X, Y, -Z, color="gray", alpha=.3) # Now plot the bottom half + + +# plot the distributions +pl.figure(1) +ax = pl.axes(projection='3d') +plot_sphere(ax) +ax.scatter(x0[:, 0], x0[:, 1], x0[:, 2], label='Data samples', alpha=0.5) +ax.set_title('Data distribution') +ax.legend() + + +# %% +# Gradient descent +# ---------------- + +x = x0.clone() +x.requires_grad_(True) + +n_iter = 500 +lr = 100 + +losses = [] +xvisu = torch.zeros(n_iter, N, 3) + +for i in range(n_iter): + sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500) + grad_x = torch.autograd.grad(sw, x)[0] + + x = x - lr * grad_x + x = F.normalize(x, p=2, dim=1) + + losses.append(sw.item()) + xvisu[i, :, :] = x.detach().clone() + + if i % 100 == 0: + print("Iter: {:3d}, loss={}".format(i, losses[-1])) + +pl.figure(1) +pl.semilogy(losses) +pl.grid() +pl.title('SSW') +pl.xlabel("Iterations") + + +# %% +# Plot trajectories of generated samples along iterations +# ------------------------------------------------------- + +ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499] + +fig = pl.figure(3, (10, 10)) +for i in range(9): + # pl.subplot(3, 3, i + 1) + # ax = pl.axes(projection='3d') + ax = fig.add_subplot(3, 3, i + 1, projection='3d') + plot_sphere(ax) + ax.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], xvisu[ivisu[i], :, 2], label='Data samples', alpha=0.5) + ax.set_title('Iter. {}'.format(ivisu[i])) + #ax.axis("off") + if i == 0: + ax.legend() + + +# %% +# Animate trajectories of generated samples along iteration +# ------------------------------------------------------- + +pl.figure(4, (8, 8)) + + +def _update_plot(i): + i = 3 * i + pl.clf() + ax = pl.axes(projection='3d') + plot_sphere(ax) + ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples$', alpha=0.5) + ax.axis("off") + ax.set_xlim((-1.5, 1.5)) + ax.set_ylim((-1.5, 1.5)) + ax.set_title('Iter. {}'.format(i)) + return 1 + + +print(xvisu.shape) + +i = 0 +ax = pl.axes(projection='3d') +plot_sphere(ax) +ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples from $G\#\mu_n$', alpha=0.5) +ax.axis("off") +ax.set_xlim((-1.5, 1.5)) +ax.set_ylim((-1.5, 1.5)) +ax.set_title('Iter. {}'.format(ivisu[i])) + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000) +# %% diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py index 2373e99..8096245 100644 --- a/examples/barycenters/plot_barycenter_1D.py +++ b/examples/barycenters/plot_barycenter_1D.py @@ -106,7 +106,7 @@ for i, z in enumerate(zs): ys = B_l2[:, i] verts.append(list(zip(x, ys))) -ax = plt.gcf().gca(projection='3d') +ax = plt.gcf().add_subplot(projection='3d') poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) @@ -128,7 +128,7 @@ for i, z in enumerate(zs): ys = B_wass[:, i] verts.append(list(zip(x, ys))) -ax = plt.gcf().gca(projection='3d') +ax = plt.gcf().add_subplot(projection='3d') poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py index 226dfeb..f4a13dd 100644 --- a/examples/barycenters/plot_free_support_barycenter.py +++ b/examples/barycenters/plot_free_support_barycenter.py @@ -4,13 +4,14 @@ 2D free support Wasserstein barycenters of distributions ======================================================== -Illustration of 2D Wasserstein barycenters if distributions are weighted +Illustration of 2D Wasserstein and Sinkhorn barycenters if distributions are weighted sum of diracs. """ # Authors: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp> # Rémi Flamary <remi.flamary@polytechnique.edu> +# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr> # # License: MIT License @@ -48,7 +49,7 @@ pl.title('Distributions') # %% -# Compute free support barycenter +# Compute free support Wasserstein barycenter # ------------------------------- k = 200 # number of Diracs of the barycenter @@ -58,7 +59,28 @@ b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, on X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b) # %% -# Plot the barycenter +# Plot the Wasserstein barycenter +# --------- + +pl.figure(2, (8, 3)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) +pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter') +pl.title('Data measures and their barycenter') +pl.legend(loc="lower right") +pl.show() + +# %% +# Compute free support Sinkhorn barycenter + +k = 200 # number of Diracs of the barycenter +X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations +b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized) + +X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, 20, b, numItermax=15) + +# %% +# Plot the Wasserstein barycenter # --------- pl.figure(2, (8, 3)) diff --git a/examples/barycenters/plot_free_support_sinkhorn_barycenter.py b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py new file mode 100644 index 0000000..ebe1f3b --- /dev/null +++ b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +""" +======================================================== +2D free support Sinkhorn barycenters of distributions +======================================================== + +Illustration of Sinkhorn barycenter calculation between empirical distributions understood as point clouds + +""" + +# Authors: Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr> +# +# License: MIT License + +import numpy as np +import matplotlib.pyplot as plt +import ot + +# %% +# General Parameters +# ------------------ +reg = 1e-2 # Entropic Regularization +numItermax = 20 # Maximum number of iterations for the Barycenter algorithm +numInnerItermax = 50 # Maximum number of sinkhorn iterations +n_samples = 200 + +# %% +# Generate Data +# ------------- + +X1 = np.random.randn(200, 2) +X2 = 2 * np.concatenate([ + np.concatenate([- np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1), + np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1), + np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1), + np.concatenate([np.linspace(1, -1, 50)[:, None], - np.ones([50, 1])], axis=1), +], axis=0) +X3 = np.random.randn(200, 2) +X3 = 2 * (X3 / np.linalg.norm(X3, axis=1)[:, None]) +X4 = np.random.multivariate_normal(np.array([0, 0]), np.array([[1., 0.5], [0.5, 1.]]), size=200) + +a1, a2, a3, a4 = ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)) + +# %% +# Inspect generated distributions +# ------------------------------- + +fig, axes = plt.subplots(1, 4, figsize=(16, 4)) + +axes[0].scatter(x=X1[:, 0], y=X1[:, 1], c='steelblue', edgecolor='k') +axes[1].scatter(x=X2[:, 0], y=X2[:, 1], c='steelblue', edgecolor='k') +axes[2].scatter(x=X3[:, 0], y=X3[:, 1], c='steelblue', edgecolor='k') +axes[3].scatter(x=X4[:, 0], y=X4[:, 1], c='steelblue', edgecolor='k') + +axes[0].set_xlim([-3, 3]) +axes[0].set_ylim([-3, 3]) +axes[0].set_title('Distribution 1') + +axes[1].set_xlim([-3, 3]) +axes[1].set_ylim([-3, 3]) +axes[1].set_title('Distribution 2') + +axes[2].set_xlim([-3, 3]) +axes[2].set_ylim([-3, 3]) +axes[2].set_title('Distribution 3') + +axes[3].set_xlim([-3, 3]) +axes[3].set_ylim([-3, 3]) +axes[3].set_title('Distribution 4') + +plt.tight_layout() +plt.show() + +# %% +# Interpolating Empirical Distributions +# ------------------------------------- + +fig = plt.figure(figsize=(10, 10)) + +weights = np.array([ + [3 / 3, 0 / 3], + [2 / 3, 1 / 3], + [1 / 3, 2 / 3], + [0 / 3, 3 / 3], +]).astype(np.float32) + +for k in range(4): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X1, X2], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (0, k)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +for k in range(1, 4, 1): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X1, X3], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (k, 0)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +for k in range(1, 4, 1): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X3, X4], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (3, k)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +for k in range(1, 3, 1): + XB_init = np.random.randn(n_samples, 2) + XB = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations=[X2, X4], + measures_weights=[a1, a2], + weights=weights[k], + X_init=XB_init, + reg=reg, + numItermax=numItermax, + numInnerItermax=numInnerItermax + ) + ax = plt.subplot2grid((4, 4), (k, 3)) + ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.set_xlim([-3, 3]) + ax.set_ylim([-3, 3]) + +plt.show() diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py new file mode 100644 index 0000000..e685ec7 --- /dev/null +++ b/examples/barycenters/plot_generalized_free_support_barycenter.py @@ -0,0 +1,155 @@ +# -*- coding: utf-8 -*- +""" +======================================= +Generalized Wasserstein Barycenter Demo +======================================= + +This example illustrates the computation of Generalized Wasserstein Barycenter +as proposed in [42]. + + +[42] Delon, J., Gozlan, N., and Saint-Dizier, A.. +Generalized Wasserstein barycenters between probability measures living on different subspaces. +arXiv preprint arXiv:2105.09755, 2021. + +""" + +# Author: Eloi Tanguy <eloi.tanguy@polytechnique.edu> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.pylab as pl +import ot +import matplotlib.animation as animation + +######################## +# Generate and plot data +# ---------------------- + +# Input measures +sub_sample_factor = 8 +I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] +I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2] +I3 = pl.imread('../../data/heart.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2] + +sz = I1.shape[0] +UU, VV = np.meshgrid(np.arange(sz), np.arange(sz)) + +# Input measure locations in their respective 2D spaces +X_list = [np.stack((UU[im == 0], VV[im == 0]), 1) * 1.0 for im in [I1, I2, I3]] + +# Input measure weights +a_list = [ot.unif(x.shape[0]) for x in X_list] + +# Projections 3D -> 2D +P1 = np.array([[1, 0, 0], [0, 1, 0]]) +P2 = np.array([[0, 1, 0], [0, 0, 1]]) +P3 = np.array([[1, 0, 0], [0, 0, 1]]) +P_list = [P1, P2, P3] + +# Barycenter weights +weights = np.array([1 / 3, 1 / 3, 1 / 3]) + +# Number of barycenter points to compute +n_samples_bary = 150 + +# Send the input measures into 3D space for visualisation +X_visu = [Xi @ Pi for (Xi, Pi) in zip(X_list, P_list)] + +# Plot the input data +fig = plt.figure(figsize=(3, 3)) +axis = fig.add_subplot(1, 1, 1, projection="3d") +for Xi in X_visu: + axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +axis.view_init(azim=45) +axis.set_xticks([]) +axis.set_yticks([]) +axis.set_zticks([]) +plt.show() + +################################# +# Barycenter computation and plot +# ------------------------------- + +Y = ot.lp.generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary) +fig = plt.figure(figsize=(3, 3)) + +axis = fig.add_subplot(1, 1, 1, projection="3d") +for Xi in X_visu: + axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +axis.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +axis.view_init(azim=45) +axis.set_xticks([]) +axis.set_yticks([]) +axis.set_zticks([]) +plt.show() + + +############################# +# Plotting projection matches +# --------------------------- + +fig = plt.figure(figsize=(9, 3)) + +ax = fig.add_subplot(1, 3, 1, projection='3d') +for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +ax.view_init(elev=0, azim=0) +ax.set_xticks([]) +ax.set_yticks([]) +ax.set_zticks([]) + +ax = fig.add_subplot(1, 3, 2, projection='3d') +for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +ax.view_init(elev=0, azim=90) +ax.set_xticks([]) +ax.set_yticks([]) +ax.set_zticks([]) + +ax = fig.add_subplot(1, 3, 3, projection='3d') +for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +ax.view_init(elev=90, azim=0) +ax.set_xticks([]) +ax.set_yticks([]) +ax.set_zticks([]) + +plt.tight_layout() +plt.show() + +############################################## +# Rotation animation +# -------------------------------------------- + +fig = plt.figure(figsize=(7, 7)) +ax = fig.add_subplot(1, 1, 1, projection="3d") + + +def _init(): + for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) + ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) + ax.view_init(elev=0, azim=0) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_zticks([]) + return fig, + + +def _update_plot(i): + if i < 45: + ax.view_init(elev=0, azim=4 * i) + else: + ax.view_init(elev=i - 45, azim=4 * i) + return fig, + + +ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=136, interval=50, blit=True, repeat_delay=2000) diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py index a44096a..8284a2a 100644 --- a/examples/domain-adaptation/plot_otda_linear_mapping.py +++ b/examples/domain-adaptation/plot_otda_linear_mapping.py @@ -61,7 +61,7 @@ plt.plot(xt[:, 0], xt[:, 1], 'o') # Estimate linear mapping and transport # ------------------------------------- -Ae, be = ot.da.OT_mapping_linear(xs, xt) +Ae, be = ot.gaussian.empirical_bures_wasserstein_mapping(xs, xt) xst = xs.dot(Ae) + be diff --git a/examples/gromov/plot_barycenter_fgw.py b/examples/gromov/plot_barycenter_fgw.py index 556e08f..dc3c6aa 100644 --- a/examples/gromov/plot_barycenter_fgw.py +++ b/examples/gromov/plot_barycenter_fgw.py @@ -174,7 +174,7 @@ A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95, log=True) # ------------------------- #%% Create the barycenter -bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0])) +bary = nx.from_numpy_array(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0])) for i, v in enumerate(A.ravel()): bary.add_node(i, attr_name=v) diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py index 5a362cf..05074dc 100644 --- a/examples/gromov/plot_gromov.py +++ b/examples/gromov/plot_gromov.py @@ -3,7 +3,6 @@ ==========================
Gromov-Wasserstein example
==========================
-
This example is designed to show how to use the Gromov-Wassertsein distance
computation in POT.
"""
diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py index 7fe081f..08ec610 100755 --- a/examples/gromov/plot_gromov_barycenter.py +++ b/examples/gromov/plot_gromov_barycenter.py @@ -110,8 +110,7 @@ for nb in range(4): if shapes[nb][i, j] < 0.95:
xs[nb].append([j, 8 - i])
-xs = np.array([np.array(xs[0]), np.array(xs[1]),
- np.array(xs[2]), np.array(xs[3])])
+xs = [np.array(xs[s]) for s in range(S)]
##############################################################################
# Barycenter computation
diff --git a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py index 1fdc3b9..7585944 100755 --- a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py +++ b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py @@ -45,10 +45,11 @@ from ot.gromov import gromov_wasserstein_linear_unmixing, gromov_wasserstein_dic import ot import networkx from networkx.generators.community import stochastic_block_model as sbm -# %% -# ============================================================================= + +############################################################################# +# # Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters. -# ============================================================================= +# --------------------------------------------- np.random.seed(42) @@ -109,10 +110,10 @@ for idx_c, c in enumerate(clusters): pl.tight_layout() pl.show() -# %% -# ============================================================================= +############################################################################# +# # Estimate the gromov-wasserstein dictionary from the dataset -# ============================================================================= +# --------------------------------------------- np.random.seed(0) @@ -140,10 +141,10 @@ pl.ylabel('loss', fontsize=12) pl.tight_layout() pl.show() -# %% -# ============================================================================= +############################################################################# +# # Visualization of the estimated dictionary atoms -# ============================================================================= +# --------------------------------------------- # Continuous connections between nodes of the atoms are colored in shades of grey (1: dark / 2: white) @@ -164,10 +165,11 @@ for idx_atom, atom in enumerate(Cdict_GW): pl.axis("off") pl.tight_layout() pl.show() -#%% -# ============================================================================= + +############################################################################# +# # Visualization of the embedding space -# ============================================================================= +# --------------------------------------------- unmixings = [] reconstruction_errors = [] @@ -211,11 +213,11 @@ pl.axis('off') pl.legend(fontsize=11) pl.tight_layout() pl.show() -# %% -# ============================================================================= -# Endow the dataset with node features -# ============================================================================= +############################################################################# +# +# Endow the dataset with node features +# --------------------------------------------- # We follow this feature assignment on all nodes of a graph depending on its label/number of clusters # 1 cluster --> 0 as nodes feature # 2 clusters --> 1 as nodes feature @@ -251,10 +253,11 @@ for idx_c, c in enumerate(clusters): pl.axis("off") pl.tight_layout() pl.show() -# %% -# ============================================================================= + +############################################################################# +# # Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs -# ============================================================================= +# --------------------------------------------- np.random.seed(0) ps = [ot.unif(C.shape[0]) for C in dataset] D = 3 # 6 atoms instead of 3 @@ -280,10 +283,10 @@ pl.ylabel('loss', fontsize=12) pl.tight_layout() pl.show() -# %% -# ============================================================================= +############################################################################# +# # Visualization of the estimated dictionary atoms -# ============================================================================= +# --------------------------------------------- pl.figure(7, (12, 8)) pl.clf() @@ -307,10 +310,10 @@ for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)): pl.tight_layout() pl.show() -# %% -# ============================================================================= +############################################################################# +# # Visualization of the embedding space -# ============================================================================= +# --------------------------------------------- unmixings = [] reconstruction_errors = [] diff --git a/examples/gromov/plot_semirelaxed_fgw.py b/examples/gromov/plot_semirelaxed_fgw.py new file mode 100644 index 0000000..ef4b286 --- /dev/null +++ b/examples/gromov/plot_semirelaxed_fgw.py @@ -0,0 +1,301 @@ +# -*- coding: utf-8 -*- +""" +========================== +Semi-relaxed (Fused) Gromov-Wasserstein example +========================== + +This example is designed to show how to use the semi-relaxed Gromov-Wasserstein +and the semi-relaxed Fused Gromov-Wasserstein divergences. + +sr(F)GW between two graphs G1 and G2 searches for a reweighing of the nodes of +G2 at a minimal (F)GW distance from G1. + +First, we generate two graphs following Stochastic Block Models, then show +how to compute their srGW matchings and illustrate them. These graphs are then +endowed with node features and we follow the same process with srFGW. + +[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. +"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" +International Conference on Learning Representations (ICLR), 2021. +""" + +# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +import numpy as np +import matplotlib.pylab as pl +from ot.gromov import semirelaxed_gromov_wasserstein, semirelaxed_fused_gromov_wasserstein, gromov_wasserstein, fused_gromov_wasserstein +import networkx +from networkx.generators.community import stochastic_block_model as sbm + +############################################################################# +# +# Generate two graphs following Stochastic Block models of 2 and 3 clusters. +# --------------------------------------------- + + +N2 = 20 # 2 communities +N3 = 30 # 3 communities +p2 = [[1., 0.1], + [0.1, 0.9]] +p3 = [[1., 0.1, 0.], + [0.1, 0.95, 0.1], + [0., 0.1, 0.9]] +G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2) +G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3) + + +C2 = networkx.to_numpy_array(G2) +C3 = networkx.to_numpy_array(G3) + +h2 = np.ones(C2.shape[0]) / C2.shape[0] +h3 = np.ones(C3.shape[0]) / C3.shape[0] + +# Add weights on the edges for visualization later on +weight_intra_G2 = 5 +weight_inter_G2 = 0.5 +weight_intra_G3 = 1. +weight_inter_G3 = 1.5 + +weightedG2 = networkx.Graph() +part_G2 = [G2.nodes[i]['block'] for i in range(N2)] + +for node in G2.nodes(): + weightedG2.add_node(node) +for i, j in G2.edges(): + if part_G2[i] == part_G2[j]: + weightedG2.add_edge(i, j, weight=weight_intra_G2) + else: + weightedG2.add_edge(i, j, weight=weight_inter_G2) + +weightedG3 = networkx.Graph() +part_G3 = [G3.nodes[i]['block'] for i in range(N3)] + +for node in G3.nodes(): + weightedG3.add_node(node) +for i, j in G3.edges(): + if part_G3[i] == part_G3[j]: + weightedG3.add_edge(i, j, weight=weight_intra_G3) + else: + weightedG3.add_edge(i, j, weight=weight_inter_G3) + +############################################################################# +# +# Compute their semi-relaxed Gromov-Wasserstein divergences +# --------------------------------------------- + +# 0) GW(C2, h2, C3, h3) for reference +OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True) +gw = log['gw_dist'] + +# 1) srGW(C2, h2, C3) +OT_23, log_23 = semirelaxed_gromov_wasserstein(C2, C3, h2, symmetric=True, + log=True, G0=None) +srgw_23 = log_23['srgw_dist'] + +# 2) srGW(C3, h3, C2) + +OT_32, log_32 = semirelaxed_gromov_wasserstein(C3, C2, h3, symmetric=None, + log=True, G0=OT.T) +srgw_32 = log_32['srgw_dist'] + +print('GW(C2, C3) = ', gw) +print('srGW(C2, h2, C3) = ', srgw_23) +print('srGW(C3, h3, C2) = ', srgw_32) + + +############################################################################# +# +# Visualization of the semi-relaxed Gromov-Wasserstein matchings +# --------------------------------------------- +# +# We color nodes of the graph on the right - then project its node colors +# based on the optimal transport plan from the srGW matching + + +def draw_graph(G, C, nodes_color_part, Gweights=None, + pos=None, edge_color='black', node_size=None, + shiftx=0, seed=0): + + if (pos is None): + pos = networkx.spring_layout(G, scale=1., seed=seed) + + if shiftx != 0: + for k, v in pos.items(): + v[0] = v[0] + shiftx + + alpha_edge = 0.7 + width_edge = 1.8 + if Gweights is None: + networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color) + else: + # We make more visible connections between activated nodes + n = len(Gweights) + edgelist_activated = [] + edgelist_deactivated = [] + for i in range(n): + for j in range(n): + if Gweights[i] * Gweights[j] * C[i, j] > 0: + edgelist_activated.append((i, j)) + elif C[i, j] > 0: + edgelist_deactivated.append((i, j)) + + networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated, + width=width_edge, alpha=alpha_edge, + edge_color=edge_color) + networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated, + width=width_edge, alpha=0.1, + edge_color=edge_color) + + if Gweights is None: + for node, node_color in enumerate(nodes_color_part): + networkx.draw_networkx_nodes(G, pos, nodelist=[node], + node_size=node_size, alpha=1, + node_color=node_color) + else: + scaled_Gweights = Gweights / (0.5 * Gweights.max()) + nodes_size = node_size * scaled_Gweights + for node, node_color in enumerate(nodes_color_part): + networkx.draw_networkx_nodes(G, pos, nodelist=[node], + node_size=nodes_size[node], alpha=1, + node_color=node_color) + return pos + + +def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, + p1, p2, T, pos1=None, pos2=None, + shiftx=4, switchx=False, node_size=70, + seed_G1=0, seed_G2=0): + starting_color = 0 + # get graphs partition and their coloring + part1 = part_G1.copy() + unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)] + nodes_color_part1 = [] + for cluster in part1: + nodes_color_part1.append(unique_colors[cluster]) + + nodes_color_part2 = [] + # T: getting colors assignment from argmin of columns + for i in range(len(G2.nodes())): + j = np.argmax(T[:, i]) + nodes_color_part2.append(nodes_color_part1[j]) + pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1, + pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1) + pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2, + node_size=node_size, shiftx=shiftx, seed=seed_G2) + for k1, v1 in pos1.items(): + for k2, v2 in pos2.items(): + if (T[k1, k2] > 0): + pl.plot([pos1[k1][0], pos2[k2][0]], + [pos1[k1][1], pos2[k2][1]], + '-', lw=0.8, alpha=0.5, + color=nodes_color_part1[k1]) + return pos1, pos2 + + +node_size = 40 +fontsize = 10 +seed_G2 = 0 +seed_G3 = 4 + +pl.figure(1, figsize=(8, 2.5)) +pl.clf() +pl.subplot(121) +pl.axis('off') +pl.axis +pl.title(r'srGW$(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$' % (np.round(srgw_23, 3)), fontsize=fontsize) + +hbar2 = OT_23.sum(axis=0) +pos1, pos2 = draw_transp_colored_srGW( + weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23, + shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) +pl.subplot(122) +pl.axis('off') +hbar3 = OT_32.sum(axis=0) +pl.title(r'srGW$(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$' % (np.round(srgw_32, 3)), fontsize=fontsize) +pos1, pos2 = draw_transp_colored_srGW( + weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32, + pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0) +pl.tight_layout() + +pl.show() + +############################################################################# +# +# Add node features +# --------------------------------------------- + +# We add node features with given mean - by clusters +# and inversely proportional to clusters' intra-connectivity + +F2 = np.zeros((N2, 1)) +for i, c in enumerate(part_G2): + F2[i, 0] = np.random.normal(loc=c, scale=0.01) + +F3 = np.zeros((N3, 1)) +for i, c in enumerate(part_G3): + F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01) + +############################################################################# +# +# Compute their semi-relaxed Fused Gromov-Wasserstein divergences +# --------------------------------------------- + +alpha = 0.5 +# Compute pairwise euclidean distance between node features +M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T) + +# 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference + +OT, log = fused_gromov_wasserstein( + M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True) +fgw = log['fgw_dist'] + +# 1) srFGW(C2, F2, h2, C3, F3) +OT_23, log_23 = semirelaxed_fused_gromov_wasserstein( + M, C2, C3, h2, symmetric=True, alpha=0.5, log=True, G0=None) +srfgw_23 = log_23['srfgw_dist'] + +# 2) srFGW(C3, F3, h3, C2, F2) + +OT_32, log_32 = semirelaxed_fused_gromov_wasserstein( + M.T, C3, C2, h3, symmetric=None, alpha=alpha, log=True, G0=None) +srfgw_32 = log_32['srfgw_dist'] + +print('FGW(C2, F2, C3, F3) = ', fgw) +print('srGW(C2, F2, h2, C3, F3) = ', srfgw_23) +print('srGW(C3, F3, h3, C2, F2) = ', srfgw_32) + +############################################################################# +# +# Visualization of the semi-relaxed Fused Gromov-Wasserstein matchings +# --------------------------------------------- +# +# We color nodes of the graph on the right - then project its node colors +# based on the optimal transport plan from the srFGW matching +# NB: colors refer to clusters - not to node features + +pl.figure(2, figsize=(8, 2.5)) +pl.clf() +pl.subplot(121) +pl.axis('off') +pl.axis +pl.title(r'srFGW$(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$' % (np.round(srfgw_23, 3)), fontsize=fontsize) + +hbar2 = OT_23.sum(axis=0) +pos1, pos2 = draw_transp_colored_srGW( + weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23, + shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) +pl.subplot(122) +pl.axis('off') +hbar3 = OT_32.sum(axis=0) +pl.title(r'srFGW$(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$' % (np.round(srfgw_32, 3)), fontsize=fontsize) +pos1, pos2 = draw_transp_colored_srGW( + weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32, + pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0) +pl.tight_layout() + +pl.show() diff --git a/examples/others/plot_COOT.py b/examples/others/plot_COOT.py new file mode 100644 index 0000000..98c1ce1 --- /dev/null +++ b/examples/others/plot_COOT.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +r""" +=================================================== +Row and column alignments with CO-Optimal Transport +=================================================== + +This example is designed to show how to use the CO-Optimal Transport [47]_ in POT. +CO-Optimal Transport allows to calculate the distance between two **arbitrary-size** +matrices, and to align their rows and columns. In this example, we consider two +random matrices :math:`X_1` and :math:`X_2` defined by +:math:`(X_1)_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi) + \sigma \mathcal N(0,1)` +and :math:`(X_2)_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi) + \sigma \mathcal N(0,1)`. + +.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). + `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_. + Advances in Neural Information Processing Systems, 33. +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# Quang Huy Tran <quang-huy.tran@univ-ubs.fr> +# License: MIT License + +from matplotlib.patches import ConnectionPatch +import matplotlib.pylab as pl +import numpy as np +from ot.coot import co_optimal_transport as coot +from ot.coot import co_optimal_transport2 as coot2 + +# %% +# Generating two random matrices + +n1 = 20 +n2 = 10 +d1 = 16 +d2 = 8 +sigma = 0.2 + +X1 = ( + np.cos(np.arange(n1) * np.pi / n1)[:, None] + + np.cos(np.arange(d1) * np.pi / d1)[None, :] + + sigma * np.random.randn(n1, d1) +) +X2 = ( + np.cos(np.arange(n2) * np.pi / n2)[:, None] + + np.cos(np.arange(d2) * np.pi / d2)[None, :] + + sigma * np.random.randn(n2, d2) +) + +# %% +# Visualizing the matrices + +pl.figure(1, (8, 5)) +pl.subplot(1, 2, 1) +pl.imshow(X1) +pl.title('$X_1$') + +pl.subplot(1, 2, 2) +pl.imshow(X2) +pl.title("$X_2$") + +pl.tight_layout() + +# %% +# Visualizing the alignments of rows and columns, and calculating the CO-Optimal Transport distance + +pi_sample, pi_feature, log = coot(X1, X2, log=True, verbose=True) +coot_distance = coot2(X1, X2) +print('CO-Optimal Transport distance = {:.5f}'.format(coot_distance)) + +fig = pl.figure(4, (9, 7)) +pl.clf() + +ax1 = pl.subplot(2, 2, 3) +pl.imshow(X1) +pl.xlabel('$X_1$') + +ax2 = pl.subplot(2, 2, 2) +ax2.yaxis.tick_right() +pl.imshow(np.transpose(X2)) +pl.title("Transpose($X_2$)") +ax2.xaxis.tick_top() + +for i in range(n1): + j = np.argmax(pi_sample[i, :]) + xyA = (d1 - .5, i) + xyB = (j, d2 - .5) + con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData, + coordsB=ax2.transData, color="black") + fig.add_artist(con) + +for i in range(d1): + j = np.argmax(pi_feature[i, :]) + xyA = (i, -.5) + xyB = (-.5, j) + con = ConnectionPatch( + xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue") + fig.add_artist(con) diff --git a/examples/others/plot_learning_weights_with_COOT.py b/examples/others/plot_learning_weights_with_COOT.py new file mode 100644 index 0000000..cb115c3 --- /dev/null +++ b/examples/others/plot_learning_weights_with_COOT.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +r""" +=============================================================== +Learning sample marginal distribution with CO-Optimal Transport +=============================================================== + +In this example, we illustrate how to estimate the sample marginal distribution which minimizes +the CO-Optimal Transport distance [47]_ between two matrices. More precisely, given a source data +:math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a target matrix :math:`Y` associated with a fixed +histogram on features :math:`\mu_y^{(f)}`, we want to solve the following problem + +.. math:: + \min_{\mu_y^{(s)} \in \Delta} \text{COOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right) + +where :math:`\Delta` is the probability simplex. This minimization is done with a +simple projected gradient descent in PyTorch. We use the automatic backend of POT that +allows us to compute the CO-Optimal Transport distance with :func:`ot.coot.co_optimal_transport2` +with differentiable losses. + +.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). + `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_. + Advances in Neural Information Processing Systems, 33. +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# Quang Huy Tran <quang-huy.tran@univ-ubs.fr> +# License: MIT License + +from matplotlib.patches import ConnectionPatch +import torch +import numpy as np + +import matplotlib.pyplot as pl +import ot + +from ot.coot import co_optimal_transport as coot +from ot.coot import co_optimal_transport2 as coot2 + + +# %% +# Generate data +# ------------- +# The source and clean target matrices are generated by +# :math:`X_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi)` and +# :math:`Y_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi)`. +# The target matrix is then contaminated by adding 5 row outliers. +# Intuitively, we expect that the estimated sample distribution should ignore these outliers, +# i.e. their weights should be zero. + +np.random.seed(182) + +n1, d1 = 20, 16 +n2, d2 = 10, 8 +n = 15 + +X = ( + torch.cos(torch.arange(n1) * torch.pi / n1)[:, None] + + torch.cos(torch.arange(d1) * torch.pi / d1)[None, :] +) + +# Generate clean target data mixed with outliers +Y_noisy = torch.randn((n, d2)) * 10.0 +Y_noisy[:n2, :] = ( + torch.cos(torch.arange(n2) * torch.pi / n2)[:, None] + + torch.cos(torch.arange(d2) * torch.pi / d2)[None, :] +) +Y = Y_noisy[:n2, :] + +X, Y_noisy, Y = X.double(), Y_noisy.double(), Y.double() + +fig, axes = pl.subplots(nrows=1, ncols=3, figsize=(12, 5)) +axes[0].imshow(X, vmin=-2, vmax=2) +axes[0].set_title('$X$') + +axes[1].imshow(Y, vmin=-2, vmax=2) +axes[1].set_title('Clean $Y$') + +axes[2].imshow(Y_noisy, vmin=-2, vmax=2) +axes[2].set_title('Noisy $Y$') + +pl.tight_layout() + +# %% +# Optimize the COOT distance with respect to the sample marginal distribution +# --------------------------------------------------------------------------- + +losses = [] +lr = 1e-3 +niter = 1000 + +b = torch.tensor(ot.unif(n), requires_grad=True) + +for i in range(niter): + + loss = coot2(X, Y_noisy, wy_samp=b, log=False, verbose=False) + losses.append(float(loss)) + + loss.backward() + + with torch.no_grad(): + b -= lr * b.grad # gradient step + b[:] = ot.utils.proj_simplex(b) # projection on the simplex + + b.grad.zero_() + +# Estimated sample marginal distribution and training loss curve +pl.plot(losses[10:]) +pl.title('CO-Optimal Transport distance') + +print(f"Marginal distribution = {b.detach().numpy()}") + +# %% +# Visualizing the row and column alignments with the estimated sample marginal distribution +# ----------------------------------------------------------------------------------------- +# +# Clearly, the learned marginal distribution completely and successfully ignores the 5 outliers. + +X, Y_noisy = X.numpy(), Y_noisy.numpy() +b = b.detach().numpy() + +pi_sample, pi_feature = coot(X, Y_noisy, wy_samp=b, log=False, verbose=True) + +fig = pl.figure(4, (9, 7)) +pl.clf() + +ax1 = pl.subplot(2, 2, 3) +pl.imshow(X, vmin=-2, vmax=2) +pl.xlabel('$X$') + +ax2 = pl.subplot(2, 2, 2) +ax2.yaxis.tick_right() +pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2) +pl.title("Transpose(Noisy $Y$)") +ax2.xaxis.tick_top() + +for i in range(n1): + j = np.argmax(pi_sample[i, :]) + xyA = (d1 - .5, i) + xyB = (j, d2 - .5) + con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData, + coordsB=ax2.transData, color="black") + fig.add_artist(con) + +for i in range(d1): + j = np.argmax(pi_feature[i, :]) + xyA = (i, -.5) + xyB = (-.5, j) + con = ConnectionPatch( + xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue") + fig.add_artist(con) diff --git a/examples/plot_compute_wasserstein_circle.py b/examples/plot_compute_wasserstein_circle.py new file mode 100644 index 0000000..3ede96f --- /dev/null +++ b/examples/plot_compute_wasserstein_circle.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +""" +========================= +OT distance on the Circle +========================= + +Shows how to compute the Wasserstein distance on the circle + + +""" + +# Author: Clément Bonet <clement.bonet@univ-ubs.fr> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +import ot + +from scipy.special import iv + +############################################################################## +# Plot data +# --------- + +#%% plot the distributions + + +def pdf_von_Mises(theta, mu, kappa): + pdf = np.exp(kappa * np.cos(theta - mu)) / (2.0 * np.pi * iv(0, kappa)) + return pdf + + +t = np.linspace(0, 2 * np.pi, 1000, endpoint=False) + +mu1 = 1 +kappa1 = 20 + +mu_targets = np.linspace(mu1, mu1 + 2 * np.pi, 10) + + +pdf1 = pdf_von_Mises(t, mu1, kappa1) + + +pl.figure(1) +for k, mu in enumerate(mu_targets): + pdf_t = pdf_von_Mises(t, mu, kappa1) + if k == 0: + label = "Source distributions" + else: + label = None + pl.plot(t / (2 * np.pi), pdf_t, c='b', label=label) + +pl.plot(t / (2 * np.pi), pdf1, c="r", label="Target distribution") +pl.legend() + +mu2 = 0 +kappa2 = kappa1 + +x1 = np.random.vonmises(mu1, kappa1, size=(10,)) + np.pi +x2 = np.random.vonmises(mu2, kappa2, size=(10,)) + np.pi + +angles = np.linspace(0, 2 * np.pi, 150) + +pl.figure(2) +pl.plot(np.cos(angles), np.sin(angles), c="k") +pl.xlim(-1.25, 1.25) +pl.ylim(-1.25, 1.25) +pl.scatter(np.cos(x1), np.sin(x1), c="b") +pl.scatter(np.cos(x2), np.sin(x2), c="r") + +######################################################################################### +# Compare the Euclidean Wasserstein distance with the Wasserstein distance on the circle +# --------------------------------------------------------------------------------------- +# This examples illustrates the periodicity of the Wasserstein distance on the circle. +# We choose as target distribution a von Mises distribution with mean :math:`\mu_{\mathrm{target}}` +# and :math:`\kappa=20`. Then, we compare the distances with samples obtained from a von Mises distribution +# with parameters :math:`\mu_{\mathrm{source}}` and :math:`\kappa=20`. +# The Wasserstein distance on the circle takes into account the periodicity +# and attains its maximum in :math:`\mu_{\mathrm{target}}+1` (the antipodal point) contrary to the +# Euclidean version. + +#%% Compute and plot distributions + +mu_targets = np.linspace(0, 2 * np.pi, 200) +xs = np.random.vonmises(mu1 - np.pi, kappa1, size=(500,)) + np.pi + +n_try = 5 + +xts = np.zeros((n_try, 200, 500)) +for i in range(n_try): + for k, mu in enumerate(mu_targets): + # np.random.vonmises deals with data on [-pi, pi[ + xt = np.random.vonmises(mu - np.pi, kappa2, size=(500,)) + np.pi + xts[i, k] = xt + +# Put data on S^1=[0,1[ +xts2 = xts / (2 * np.pi) +xs2 = np.concatenate([xs[None] for k in range(200)], axis=0) / (2 * np.pi) + +L_w2_circle = np.zeros((n_try, 200)) +L_w2 = np.zeros((n_try, 200)) + +for i in range(n_try): + w2_circle = ot.wasserstein_circle(xs2.T, xts2[i].T, p=2) + w2 = ot.wasserstein_1d(xs2.T, xts2[i].T, p=2) + + L_w2_circle[i] = w2_circle + L_w2[i] = w2 + +m_w2_circle = np.mean(L_w2_circle, axis=0) +std_w2_circle = np.std(L_w2_circle, axis=0) + +m_w2 = np.mean(L_w2, axis=0) +std_w2 = np.std(L_w2, axis=0) + +pl.figure(1) +pl.plot(mu_targets / (2 * np.pi), m_w2_circle, label="Wasserstein circle") +pl.fill_between(mu_targets / (2 * np.pi), m_w2_circle - 2 * std_w2_circle, m_w2_circle + 2 * std_w2_circle, alpha=0.5) +pl.plot(mu_targets / (2 * np.pi), m_w2, label="Euclidean Wasserstein") +pl.fill_between(mu_targets / (2 * np.pi), m_w2 - 2 * std_w2, m_w2 + 2 * std_w2, alpha=0.5) +pl.vlines(x=[mu1 / (2 * np.pi)], ymin=0, ymax=np.max(w2), linestyle="--", color="k", label=r"$\mu_{\mathrm{target}}$") +pl.legend() +pl.xlabel(r"$\mu_{\mathrm{source}}$") +pl.show() + + +######################################################################## +# Wasserstein distance between von Mises and uniform for different kappa +# ---------------------------------------------------------------------- +# When :math:`\kappa=0`, the von Mises distribution is the uniform distribution on :math:`S^1`. + +#%% Compute Wasserstein between Von Mises and uniform + +kappas = np.logspace(-5, 2, 100) +n_try = 20 + +xts = np.zeros((n_try, 100, 500)) +for i in range(n_try): + for k, kappa in enumerate(kappas): + # np.random.vonmises deals with data on [-pi, pi[ + xt = np.random.vonmises(0, kappa, size=(500,)) + np.pi + xts[i, k] = xt / (2 * np.pi) + +L_w2 = np.zeros((n_try, 100)) +for i in range(n_try): + L_w2[i] = ot.semidiscrete_wasserstein2_unif_circle(xts[i].T) + +m_w2 = np.mean(L_w2, axis=0) +std_w2 = np.std(L_w2, axis=0) + +pl.figure(1) +pl.plot(kappas, m_w2) +pl.fill_between(kappas, m_w2 - std_w2, m_w2 + std_w2, alpha=0.5) +pl.title(r"Evolution of $W_2^2(vM(0,\kappa), Unif(S^1))$") +pl.xlabel(r"$\kappa$") +pl.show() + +# %% diff --git a/examples/sliced-wasserstein/plot_variance_ssw.py b/examples/sliced-wasserstein/plot_variance_ssw.py new file mode 100644 index 0000000..83d458f --- /dev/null +++ b/examples/sliced-wasserstein/plot_variance_ssw.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +Spherical Sliced Wasserstein on distributions in S^2 +==================================================== + +This example illustrates the computation of the spherical sliced Wasserstein discrepancy as +proposed in [46]. + +[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). 'Spherical Sliced-Wasserstein". International Conference on Learning Representations. + +""" + +# Author: Clément Bonet <clement.bonet@univ-ubs.fr> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import matplotlib.pylab as pl +import numpy as np + +import ot + +############################################################################## +# Generate data +# ------------- + +# %% parameters and data generation + +n = 500 # nb samples + +xs = np.random.randn(n, 3) +xt = np.random.randn(n, 3) + +xs = xs / np.sqrt(np.sum(xs**2, -1, keepdims=True)) +xt = xt / np.sqrt(np.sum(xt**2, -1, keepdims=True)) + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +############################################################################## +# Plot data +# --------- + +# %% plot samples + +fig = pl.figure(figsize=(10, 10)) +ax = pl.axes(projection='3d') +ax.grid(False) + +u, v = np.mgrid[0:2 * np.pi:30j, 0:np.pi:30j] +x = np.cos(u) * np.sin(v) +y = np.sin(u) * np.sin(v) +z = np.cos(v) +ax.plot_surface(x, y, z, color="gray", alpha=0.03) +ax.plot_wireframe(x, y, z, linewidth=1, alpha=0.25, color="gray") + +ax.scatter(xs[:, 0], xs[:, 1], xs[:, 2], label="Source") +ax.scatter(xt[:, 0], xt[:, 1], xt[:, 2], label="Target") + +fs = 10 +# Labels +ax.set_xlabel('x', fontsize=fs) +ax.set_ylabel('y', fontsize=fs) +ax.set_zlabel('z', fontsize=fs) + +ax.view_init(20, 120) +ax.set_xlim(-1.5, 1.5) +ax.set_ylim(-1.5, 1.5) +ax.set_zlim(-1.5, 1.5) + +# Ticks +ax.set_xticks([-1, 0, 1]) +ax.set_yticks([-1, 0, 1]) +ax.set_zticks([-1, 0, 1]) + +pl.legend(loc=0) +pl.title("Source and Target distribution") + +############################################################################### +# Spherical Sliced Wasserstein for different seeds and number of projections +# -------------------------------------------------------------------------- + +n_seed = 50 +n_projections_arr = np.logspace(0, 3, 25, dtype=int) +res = np.empty((n_seed, 25)) + +# %% Compute statistics +for seed in range(n_seed): + for i, n_projections in enumerate(n_projections_arr): + res[seed, i] = ot.sliced_wasserstein_sphere(xs, xt, a, b, n_projections, seed=seed, p=1) + +res_mean = np.mean(res, axis=0) +res_std = np.std(res, axis=0) + +############################################################################### +# Plot Spherical Sliced Wasserstein +# --------------------------------- + +pl.figure(2) +pl.plot(n_projections_arr, res_mean, label=r"$SSW_1$") +pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5) + +pl.legend() +pl.xscale('log') + +pl.xlabel("Number of projections") +pl.ylabel("Distance") +pl.title('Spherical Sliced Wasserstein Distance with 95% confidence inverval') + +pl.show() diff --git a/examples/unbalanced-partial/plot_UOT_barycenter_1D.py b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py index 931798b..8d227c0 100644 --- a/examples/unbalanced-partial/plot_UOT_barycenter_1D.py +++ b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py @@ -127,7 +127,7 @@ for i, z in enumerate(zs): ys = B_l2[:, i] verts.append(list(zip(x, ys))) -ax = pl.gcf().gca(projection='3d') +ax = pl.gcf().add_subplot(projection='3d') poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) poly.set_alpha(0.7) @@ -149,7 +149,7 @@ for i, z in enumerate(zs): ys = B_wass[:, i] verts.append(list(zip(x, ys))) -ax = pl.gcf().gca(projection='3d') +ax = pl.gcf().add_subplot(projection='3d') poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) poly.set_alpha(0.7) |