summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
Diffstat (limited to 'examples')
-rw-r--r--examples/backends/plot_sliced_wass_grad_flow_pytorch.py10
-rw-r--r--examples/backends/plot_ssw_unif_torch.py153
-rw-r--r--examples/barycenters/plot_barycenter_1D.py4
-rw-r--r--examples/barycenters/plot_free_support_barycenter.py28
-rw-r--r--examples/barycenters/plot_free_support_sinkhorn_barycenter.py151
-rw-r--r--examples/barycenters/plot_generalized_free_support_barycenter.py155
-rw-r--r--examples/domain-adaptation/plot_otda_linear_mapping.py2
-rw-r--r--examples/gromov/plot_barycenter_fgw.py2
-rw-r--r--examples/gromov/plot_gromov.py1
-rwxr-xr-xexamples/gromov/plot_gromov_barycenter.py3
-rwxr-xr-xexamples/gromov/plot_gromov_wasserstein_dictionary_learning.py53
-rw-r--r--examples/gromov/plot_semirelaxed_fgw.py301
-rw-r--r--examples/others/plot_COOT.py97
-rw-r--r--examples/others/plot_learning_weights_with_COOT.py150
-rw-r--r--examples/plot_compute_wasserstein_circle.py161
-rw-r--r--examples/sliced-wasserstein/plot_variance_ssw.py111
-rw-r--r--examples/unbalanced-partial/plot_UOT_barycenter_1D.py4
17 files changed, 1344 insertions, 42 deletions
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
index cf5d64d..f00de50 100644
--- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
+++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
@@ -74,7 +74,7 @@ x_all = np.zeros((nb_iter_max, x1.shape[0], 2))
loss_iter = []
# generator for random permutations
-gen = torch.Generator()
+gen = torch.Generator(device=device)
gen.manual_seed(42)
for i in range(nb_iter_max):
@@ -103,7 +103,7 @@ ax = pl.axis()
# %%
# Animate trajectories of the gradient flow along iteration
-# -------------------------------------------------------
+# ---------------------------------------------------------
pl.figure(3, (8, 4))
@@ -122,7 +122,7 @@ ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100,
# %%
# Compute the Sliced Wasserstein Barycenter
-#
+# -----------------------------------------
x1_torch = torch.tensor(x1).to(device=device)
x3_torch = torch.tensor(x3).to(device=device)
xbinit = np.random.randn(500, 2) * 10 + 16
@@ -136,7 +136,7 @@ x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2))
loss_iter = []
# generator for random permutations
-gen = torch.Generator()
+gen = torch.Generator(device=device)
gen.manual_seed(42)
alpha = 0.5
@@ -169,7 +169,7 @@ ax = pl.axis()
# %%
# Animate trajectories of the barycenter along gradient descent
-# -------------------------------------------------------
+# -------------------------------------------------------------
pl.figure(5, (8, 4))
diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py
new file mode 100644
index 0000000..7ccc2af
--- /dev/null
+++ b/examples/backends/plot_ssw_unif_torch.py
@@ -0,0 +1,153 @@
+# -*- coding: utf-8 -*-
+r"""
+================================================
+Spherical Sliced-Wasserstein Embedding on Sphere
+================================================
+
+Here, we aim at transforming samples into a uniform
+distribution on the sphere by minimizing SSW:
+
+.. math::
+ \min_{x} SSW_2(\nu, \frac{1}{n}\sum_{i=1}^n \delta_{x_i})
+
+where :math:`\nu=\mathrm{Unif}(S^1)`.
+
+"""
+
+# Author: Clément Bonet <clement.bonet@univ-ubs.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+import numpy as np
+import matplotlib.pyplot as pl
+import matplotlib.animation as animation
+import torch
+import torch.nn.functional as F
+
+import ot
+
+
+# %%
+# Data generation
+# ---------------
+
+torch.manual_seed(1)
+
+N = 1000
+x0 = torch.rand(N, 3)
+x0 = F.normalize(x0, dim=-1)
+
+
+# %%
+# Plot data
+# ---------
+
+def plot_sphere(ax):
+ xlist = np.linspace(-1.0, 1.0, 50)
+ ylist = np.linspace(-1.0, 1.0, 50)
+ r = np.linspace(1.0, 1.0, 50)
+ X, Y = np.meshgrid(xlist, ylist)
+
+ Z = np.sqrt(np.maximum(r**2 - X**2 - Y**2, 0))
+
+ ax.plot_wireframe(X, Y, Z, color="gray", alpha=.3)
+ ax.plot_wireframe(X, Y, -Z, color="gray", alpha=.3) # Now plot the bottom half
+
+
+# plot the distributions
+pl.figure(1)
+ax = pl.axes(projection='3d')
+plot_sphere(ax)
+ax.scatter(x0[:, 0], x0[:, 1], x0[:, 2], label='Data samples', alpha=0.5)
+ax.set_title('Data distribution')
+ax.legend()
+
+
+# %%
+# Gradient descent
+# ----------------
+
+x = x0.clone()
+x.requires_grad_(True)
+
+n_iter = 500
+lr = 100
+
+losses = []
+xvisu = torch.zeros(n_iter, N, 3)
+
+for i in range(n_iter):
+ sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500)
+ grad_x = torch.autograd.grad(sw, x)[0]
+
+ x = x - lr * grad_x
+ x = F.normalize(x, p=2, dim=1)
+
+ losses.append(sw.item())
+ xvisu[i, :, :] = x.detach().clone()
+
+ if i % 100 == 0:
+ print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+pl.figure(1)
+pl.semilogy(losses)
+pl.grid()
+pl.title('SSW')
+pl.xlabel("Iterations")
+
+
+# %%
+# Plot trajectories of generated samples along iterations
+# -------------------------------------------------------
+
+ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499]
+
+fig = pl.figure(3, (10, 10))
+for i in range(9):
+ # pl.subplot(3, 3, i + 1)
+ # ax = pl.axes(projection='3d')
+ ax = fig.add_subplot(3, 3, i + 1, projection='3d')
+ plot_sphere(ax)
+ ax.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], xvisu[ivisu[i], :, 2], label='Data samples', alpha=0.5)
+ ax.set_title('Iter. {}'.format(ivisu[i]))
+ #ax.axis("off")
+ if i == 0:
+ ax.legend()
+
+
+# %%
+# Animate trajectories of generated samples along iteration
+# -------------------------------------------------------
+
+pl.figure(4, (8, 8))
+
+
+def _update_plot(i):
+ i = 3 * i
+ pl.clf()
+ ax = pl.axes(projection='3d')
+ plot_sphere(ax)
+ ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples$', alpha=0.5)
+ ax.axis("off")
+ ax.set_xlim((-1.5, 1.5))
+ ax.set_ylim((-1.5, 1.5))
+ ax.set_title('Iter. {}'.format(i))
+ return 1
+
+
+print(xvisu.shape)
+
+i = 0
+ax = pl.axes(projection='3d')
+plot_sphere(ax)
+ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples from $G\#\mu_n$', alpha=0.5)
+ax.axis("off")
+ax.set_xlim((-1.5, 1.5))
+ax.set_ylim((-1.5, 1.5))
+ax.set_title('Iter. {}'.format(ivisu[i]))
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000)
+# %%
diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py
index 2373e99..8096245 100644
--- a/examples/barycenters/plot_barycenter_1D.py
+++ b/examples/barycenters/plot_barycenter_1D.py
@@ -106,7 +106,7 @@ for i, z in enumerate(zs):
ys = B_l2[:, i]
verts.append(list(zip(x, ys)))
-ax = plt.gcf().gca(projection='3d')
+ax = plt.gcf().add_subplot(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
@@ -128,7 +128,7 @@ for i, z in enumerate(zs):
ys = B_wass[:, i]
verts.append(list(zip(x, ys)))
-ax = plt.gcf().gca(projection='3d')
+ax = plt.gcf().add_subplot(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py
index 226dfeb..f4a13dd 100644
--- a/examples/barycenters/plot_free_support_barycenter.py
+++ b/examples/barycenters/plot_free_support_barycenter.py
@@ -4,13 +4,14 @@
2D free support Wasserstein barycenters of distributions
========================================================
-Illustration of 2D Wasserstein barycenters if distributions are weighted
+Illustration of 2D Wasserstein and Sinkhorn barycenters if distributions are weighted
sum of diracs.
"""
# Authors: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
# Rémi Flamary <remi.flamary@polytechnique.edu>
+# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
#
# License: MIT License
@@ -48,7 +49,7 @@ pl.title('Distributions')
# %%
-# Compute free support barycenter
+# Compute free support Wasserstein barycenter
# -------------------------------
k = 200 # number of Diracs of the barycenter
@@ -58,7 +59,28 @@ b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, on
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
# %%
-# Plot the barycenter
+# Plot the Wasserstein barycenter
+# ---------
+
+pl.figure(2, (8, 3))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter')
+pl.title('Data measures and their barycenter')
+pl.legend(loc="lower right")
+pl.show()
+
+# %%
+# Compute free support Sinkhorn barycenter
+
+k = 200 # number of Diracs of the barycenter
+X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
+b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
+
+X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, 20, b, numItermax=15)
+
+# %%
+# Plot the Wasserstein barycenter
# ---------
pl.figure(2, (8, 3))
diff --git a/examples/barycenters/plot_free_support_sinkhorn_barycenter.py b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py
new file mode 100644
index 0000000..ebe1f3b
--- /dev/null
+++ b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py
@@ -0,0 +1,151 @@
+# -*- coding: utf-8 -*-
+"""
+========================================================
+2D free support Sinkhorn barycenters of distributions
+========================================================
+
+Illustration of Sinkhorn barycenter calculation between empirical distributions understood as point clouds
+
+"""
+
+# Authors: Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pyplot as plt
+import ot
+
+# %%
+# General Parameters
+# ------------------
+reg = 1e-2 # Entropic Regularization
+numItermax = 20 # Maximum number of iterations for the Barycenter algorithm
+numInnerItermax = 50 # Maximum number of sinkhorn iterations
+n_samples = 200
+
+# %%
+# Generate Data
+# -------------
+
+X1 = np.random.randn(200, 2)
+X2 = 2 * np.concatenate([
+ np.concatenate([- np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1),
+ np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1),
+ np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1),
+ np.concatenate([np.linspace(1, -1, 50)[:, None], - np.ones([50, 1])], axis=1),
+], axis=0)
+X3 = np.random.randn(200, 2)
+X3 = 2 * (X3 / np.linalg.norm(X3, axis=1)[:, None])
+X4 = np.random.multivariate_normal(np.array([0, 0]), np.array([[1., 0.5], [0.5, 1.]]), size=200)
+
+a1, a2, a3, a4 = ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1))
+
+# %%
+# Inspect generated distributions
+# -------------------------------
+
+fig, axes = plt.subplots(1, 4, figsize=(16, 4))
+
+axes[0].scatter(x=X1[:, 0], y=X1[:, 1], c='steelblue', edgecolor='k')
+axes[1].scatter(x=X2[:, 0], y=X2[:, 1], c='steelblue', edgecolor='k')
+axes[2].scatter(x=X3[:, 0], y=X3[:, 1], c='steelblue', edgecolor='k')
+axes[3].scatter(x=X4[:, 0], y=X4[:, 1], c='steelblue', edgecolor='k')
+
+axes[0].set_xlim([-3, 3])
+axes[0].set_ylim([-3, 3])
+axes[0].set_title('Distribution 1')
+
+axes[1].set_xlim([-3, 3])
+axes[1].set_ylim([-3, 3])
+axes[1].set_title('Distribution 2')
+
+axes[2].set_xlim([-3, 3])
+axes[2].set_ylim([-3, 3])
+axes[2].set_title('Distribution 3')
+
+axes[3].set_xlim([-3, 3])
+axes[3].set_ylim([-3, 3])
+axes[3].set_title('Distribution 4')
+
+plt.tight_layout()
+plt.show()
+
+# %%
+# Interpolating Empirical Distributions
+# -------------------------------------
+
+fig = plt.figure(figsize=(10, 10))
+
+weights = np.array([
+ [3 / 3, 0 / 3],
+ [2 / 3, 1 / 3],
+ [1 / 3, 2 / 3],
+ [0 / 3, 3 / 3],
+]).astype(np.float32)
+
+for k in range(4):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X1, X2],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (0, k))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+for k in range(1, 4, 1):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X1, X3],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (k, 0))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+for k in range(1, 4, 1):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X3, X4],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (3, k))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+for k in range(1, 3, 1):
+ XB_init = np.random.randn(n_samples, 2)
+ XB = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations=[X2, X4],
+ measures_weights=[a1, a2],
+ weights=weights[k],
+ X_init=XB_init,
+ reg=reg,
+ numItermax=numItermax,
+ numInnerItermax=numInnerItermax
+ )
+ ax = plt.subplot2grid((4, 4), (k, 3))
+ ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
+ ax.set_xlim([-3, 3])
+ ax.set_ylim([-3, 3])
+
+plt.show()
diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py
new file mode 100644
index 0000000..e685ec7
--- /dev/null
+++ b/examples/barycenters/plot_generalized_free_support_barycenter.py
@@ -0,0 +1,155 @@
+# -*- coding: utf-8 -*-
+"""
+=======================================
+Generalized Wasserstein Barycenter Demo
+=======================================
+
+This example illustrates the computation of Generalized Wasserstein Barycenter
+as proposed in [42].
+
+
+[42] Delon, J., Gozlan, N., and Saint-Dizier, A..
+Generalized Wasserstein barycenters between probability measures living on different subspaces.
+arXiv preprint arXiv:2105.09755, 2021.
+
+"""
+
+# Author: Eloi Tanguy <eloi.tanguy@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.pylab as pl
+import ot
+import matplotlib.animation as animation
+
+########################
+# Generate and plot data
+# ----------------------
+
+# Input measures
+sub_sample_factor = 8
+I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2]
+I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2]
+I3 = pl.imread('../../data/heart.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2]
+
+sz = I1.shape[0]
+UU, VV = np.meshgrid(np.arange(sz), np.arange(sz))
+
+# Input measure locations in their respective 2D spaces
+X_list = [np.stack((UU[im == 0], VV[im == 0]), 1) * 1.0 for im in [I1, I2, I3]]
+
+# Input measure weights
+a_list = [ot.unif(x.shape[0]) for x in X_list]
+
+# Projections 3D -> 2D
+P1 = np.array([[1, 0, 0], [0, 1, 0]])
+P2 = np.array([[0, 1, 0], [0, 0, 1]])
+P3 = np.array([[1, 0, 0], [0, 0, 1]])
+P_list = [P1, P2, P3]
+
+# Barycenter weights
+weights = np.array([1 / 3, 1 / 3, 1 / 3])
+
+# Number of barycenter points to compute
+n_samples_bary = 150
+
+# Send the input measures into 3D space for visualisation
+X_visu = [Xi @ Pi for (Xi, Pi) in zip(X_list, P_list)]
+
+# Plot the input data
+fig = plt.figure(figsize=(3, 3))
+axis = fig.add_subplot(1, 1, 1, projection="3d")
+for Xi in X_visu:
+ axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+axis.view_init(azim=45)
+axis.set_xticks([])
+axis.set_yticks([])
+axis.set_zticks([])
+plt.show()
+
+#################################
+# Barycenter computation and plot
+# -------------------------------
+
+Y = ot.lp.generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary)
+fig = plt.figure(figsize=(3, 3))
+
+axis = fig.add_subplot(1, 1, 1, projection="3d")
+for Xi in X_visu:
+ axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+axis.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+axis.view_init(azim=45)
+axis.set_xticks([])
+axis.set_yticks([])
+axis.set_zticks([])
+plt.show()
+
+
+#############################
+# Plotting projection matches
+# ---------------------------
+
+fig = plt.figure(figsize=(9, 3))
+
+ax = fig.add_subplot(1, 3, 1, projection='3d')
+for Xi in X_visu:
+ ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+ax.view_init(elev=0, azim=0)
+ax.set_xticks([])
+ax.set_yticks([])
+ax.set_zticks([])
+
+ax = fig.add_subplot(1, 3, 2, projection='3d')
+for Xi in X_visu:
+ ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+ax.view_init(elev=0, azim=90)
+ax.set_xticks([])
+ax.set_yticks([])
+ax.set_zticks([])
+
+ax = fig.add_subplot(1, 3, 3, projection='3d')
+for Xi in X_visu:
+ ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+ax.view_init(elev=90, azim=0)
+ax.set_xticks([])
+ax.set_yticks([])
+ax.set_zticks([])
+
+plt.tight_layout()
+plt.show()
+
+##############################################
+# Rotation animation
+# --------------------------------------------
+
+fig = plt.figure(figsize=(7, 7))
+ax = fig.add_subplot(1, 1, 1, projection="3d")
+
+
+def _init():
+ for Xi in X_visu:
+ ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
+ ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
+ ax.view_init(elev=0, azim=0)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.set_zticks([])
+ return fig,
+
+
+def _update_plot(i):
+ if i < 45:
+ ax.view_init(elev=0, azim=4 * i)
+ else:
+ ax.view_init(elev=i - 45, azim=4 * i)
+ return fig,
+
+
+ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=136, interval=50, blit=True, repeat_delay=2000)
diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py
index a44096a..8284a2a 100644
--- a/examples/domain-adaptation/plot_otda_linear_mapping.py
+++ b/examples/domain-adaptation/plot_otda_linear_mapping.py
@@ -61,7 +61,7 @@ plt.plot(xt[:, 0], xt[:, 1], 'o')
# Estimate linear mapping and transport
# -------------------------------------
-Ae, be = ot.da.OT_mapping_linear(xs, xt)
+Ae, be = ot.gaussian.empirical_bures_wasserstein_mapping(xs, xt)
xst = xs.dot(Ae) + be
diff --git a/examples/gromov/plot_barycenter_fgw.py b/examples/gromov/plot_barycenter_fgw.py
index 556e08f..dc3c6aa 100644
--- a/examples/gromov/plot_barycenter_fgw.py
+++ b/examples/gromov/plot_barycenter_fgw.py
@@ -174,7 +174,7 @@ A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95, log=True)
# -------------------------
#%% Create the barycenter
-bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
+bary = nx.from_numpy_array(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
for i, v in enumerate(A.ravel()):
bary.add_node(i, attr_name=v)
diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py
index 5a362cf..05074dc 100644
--- a/examples/gromov/plot_gromov.py
+++ b/examples/gromov/plot_gromov.py
@@ -3,7 +3,6 @@
==========================
Gromov-Wasserstein example
==========================
-
This example is designed to show how to use the Gromov-Wassertsein distance
computation in POT.
"""
diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py
index 7fe081f..08ec610 100755
--- a/examples/gromov/plot_gromov_barycenter.py
+++ b/examples/gromov/plot_gromov_barycenter.py
@@ -110,8 +110,7 @@ for nb in range(4):
if shapes[nb][i, j] < 0.95:
xs[nb].append([j, 8 - i])
-xs = np.array([np.array(xs[0]), np.array(xs[1]),
- np.array(xs[2]), np.array(xs[3])])
+xs = [np.array(xs[s]) for s in range(S)]
##############################################################################
# Barycenter computation
diff --git a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
index 1fdc3b9..7585944 100755
--- a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
+++ b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
@@ -45,10 +45,11 @@ from ot.gromov import gromov_wasserstein_linear_unmixing, gromov_wasserstein_dic
import ot
import networkx
from networkx.generators.community import stochastic_block_model as sbm
-# %%
-# =============================================================================
+
+#############################################################################
+#
# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters.
-# =============================================================================
+# ---------------------------------------------
np.random.seed(42)
@@ -109,10 +110,10 @@ for idx_c, c in enumerate(clusters):
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Estimate the gromov-wasserstein dictionary from the dataset
-# =============================================================================
+# ---------------------------------------------
np.random.seed(0)
@@ -140,10 +141,10 @@ pl.ylabel('loss', fontsize=12)
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Visualization of the estimated dictionary atoms
-# =============================================================================
+# ---------------------------------------------
# Continuous connections between nodes of the atoms are colored in shades of grey (1: dark / 2: white)
@@ -164,10 +165,11 @@ for idx_atom, atom in enumerate(Cdict_GW):
pl.axis("off")
pl.tight_layout()
pl.show()
-#%%
-# =============================================================================
+
+#############################################################################
+#
# Visualization of the embedding space
-# =============================================================================
+# ---------------------------------------------
unmixings = []
reconstruction_errors = []
@@ -211,11 +213,11 @@ pl.axis('off')
pl.legend(fontsize=11)
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
-# Endow the dataset with node features
-# =============================================================================
+#############################################################################
+#
+# Endow the dataset with node features
+# ---------------------------------------------
# We follow this feature assignment on all nodes of a graph depending on its label/number of clusters
# 1 cluster --> 0 as nodes feature
# 2 clusters --> 1 as nodes feature
@@ -251,10 +253,11 @@ for idx_c, c in enumerate(clusters):
pl.axis("off")
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+
+#############################################################################
+#
# Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs
-# =============================================================================
+# ---------------------------------------------
np.random.seed(0)
ps = [ot.unif(C.shape[0]) for C in dataset]
D = 3 # 6 atoms instead of 3
@@ -280,10 +283,10 @@ pl.ylabel('loss', fontsize=12)
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Visualization of the estimated dictionary atoms
-# =============================================================================
+# ---------------------------------------------
pl.figure(7, (12, 8))
pl.clf()
@@ -307,10 +310,10 @@ for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)):
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Visualization of the embedding space
-# =============================================================================
+# ---------------------------------------------
unmixings = []
reconstruction_errors = []
diff --git a/examples/gromov/plot_semirelaxed_fgw.py b/examples/gromov/plot_semirelaxed_fgw.py
new file mode 100644
index 0000000..ef4b286
--- /dev/null
+++ b/examples/gromov/plot_semirelaxed_fgw.py
@@ -0,0 +1,301 @@
+# -*- coding: utf-8 -*-
+"""
+==========================
+Semi-relaxed (Fused) Gromov-Wasserstein example
+==========================
+
+This example is designed to show how to use the semi-relaxed Gromov-Wasserstein
+and the semi-relaxed Fused Gromov-Wasserstein divergences.
+
+sr(F)GW between two graphs G1 and G2 searches for a reweighing of the nodes of
+G2 at a minimal (F)GW distance from G1.
+
+First, we generate two graphs following Stochastic Block Models, then show
+how to compute their srGW matchings and illustrate them. These graphs are then
+endowed with node features and we follow the same process with srFGW.
+
+[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+International Conference on Learning Representations (ICLR), 2021.
+"""
+
+# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 1
+
+import numpy as np
+import matplotlib.pylab as pl
+from ot.gromov import semirelaxed_gromov_wasserstein, semirelaxed_fused_gromov_wasserstein, gromov_wasserstein, fused_gromov_wasserstein
+import networkx
+from networkx.generators.community import stochastic_block_model as sbm
+
+#############################################################################
+#
+# Generate two graphs following Stochastic Block models of 2 and 3 clusters.
+# ---------------------------------------------
+
+
+N2 = 20 # 2 communities
+N3 = 30 # 3 communities
+p2 = [[1., 0.1],
+ [0.1, 0.9]]
+p3 = [[1., 0.1, 0.],
+ [0.1, 0.95, 0.1],
+ [0., 0.1, 0.9]]
+G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2)
+G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3)
+
+
+C2 = networkx.to_numpy_array(G2)
+C3 = networkx.to_numpy_array(G3)
+
+h2 = np.ones(C2.shape[0]) / C2.shape[0]
+h3 = np.ones(C3.shape[0]) / C3.shape[0]
+
+# Add weights on the edges for visualization later on
+weight_intra_G2 = 5
+weight_inter_G2 = 0.5
+weight_intra_G3 = 1.
+weight_inter_G3 = 1.5
+
+weightedG2 = networkx.Graph()
+part_G2 = [G2.nodes[i]['block'] for i in range(N2)]
+
+for node in G2.nodes():
+ weightedG2.add_node(node)
+for i, j in G2.edges():
+ if part_G2[i] == part_G2[j]:
+ weightedG2.add_edge(i, j, weight=weight_intra_G2)
+ else:
+ weightedG2.add_edge(i, j, weight=weight_inter_G2)
+
+weightedG3 = networkx.Graph()
+part_G3 = [G3.nodes[i]['block'] for i in range(N3)]
+
+for node in G3.nodes():
+ weightedG3.add_node(node)
+for i, j in G3.edges():
+ if part_G3[i] == part_G3[j]:
+ weightedG3.add_edge(i, j, weight=weight_intra_G3)
+ else:
+ weightedG3.add_edge(i, j, weight=weight_inter_G3)
+
+#############################################################################
+#
+# Compute their semi-relaxed Gromov-Wasserstein divergences
+# ---------------------------------------------
+
+# 0) GW(C2, h2, C3, h3) for reference
+OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True)
+gw = log['gw_dist']
+
+# 1) srGW(C2, h2, C3)
+OT_23, log_23 = semirelaxed_gromov_wasserstein(C2, C3, h2, symmetric=True,
+ log=True, G0=None)
+srgw_23 = log_23['srgw_dist']
+
+# 2) srGW(C3, h3, C2)
+
+OT_32, log_32 = semirelaxed_gromov_wasserstein(C3, C2, h3, symmetric=None,
+ log=True, G0=OT.T)
+srgw_32 = log_32['srgw_dist']
+
+print('GW(C2, C3) = ', gw)
+print('srGW(C2, h2, C3) = ', srgw_23)
+print('srGW(C3, h3, C2) = ', srgw_32)
+
+
+#############################################################################
+#
+# Visualization of the semi-relaxed Gromov-Wasserstein matchings
+# ---------------------------------------------
+#
+# We color nodes of the graph on the right - then project its node colors
+# based on the optimal transport plan from the srGW matching
+
+
+def draw_graph(G, C, nodes_color_part, Gweights=None,
+ pos=None, edge_color='black', node_size=None,
+ shiftx=0, seed=0):
+
+ if (pos is None):
+ pos = networkx.spring_layout(G, scale=1., seed=seed)
+
+ if shiftx != 0:
+ for k, v in pos.items():
+ v[0] = v[0] + shiftx
+
+ alpha_edge = 0.7
+ width_edge = 1.8
+ if Gweights is None:
+ networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color)
+ else:
+ # We make more visible connections between activated nodes
+ n = len(Gweights)
+ edgelist_activated = []
+ edgelist_deactivated = []
+ for i in range(n):
+ for j in range(n):
+ if Gweights[i] * Gweights[j] * C[i, j] > 0:
+ edgelist_activated.append((i, j))
+ elif C[i, j] > 0:
+ edgelist_deactivated.append((i, j))
+
+ networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated,
+ width=width_edge, alpha=alpha_edge,
+ edge_color=edge_color)
+ networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated,
+ width=width_edge, alpha=0.1,
+ edge_color=edge_color)
+
+ if Gweights is None:
+ for node, node_color in enumerate(nodes_color_part):
+ networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+ node_size=node_size, alpha=1,
+ node_color=node_color)
+ else:
+ scaled_Gweights = Gweights / (0.5 * Gweights.max())
+ nodes_size = node_size * scaled_Gweights
+ for node, node_color in enumerate(nodes_color_part):
+ networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+ node_size=nodes_size[node], alpha=1,
+ node_color=node_color)
+ return pos
+
+
+def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1,
+ p1, p2, T, pos1=None, pos2=None,
+ shiftx=4, switchx=False, node_size=70,
+ seed_G1=0, seed_G2=0):
+ starting_color = 0
+ # get graphs partition and their coloring
+ part1 = part_G1.copy()
+ unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)]
+ nodes_color_part1 = []
+ for cluster in part1:
+ nodes_color_part1.append(unique_colors[cluster])
+
+ nodes_color_part2 = []
+ # T: getting colors assignment from argmin of columns
+ for i in range(len(G2.nodes())):
+ j = np.argmax(T[:, i])
+ nodes_color_part2.append(nodes_color_part1[j])
+ pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1,
+ pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1)
+ pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2,
+ node_size=node_size, shiftx=shiftx, seed=seed_G2)
+ for k1, v1 in pos1.items():
+ for k2, v2 in pos2.items():
+ if (T[k1, k2] > 0):
+ pl.plot([pos1[k1][0], pos2[k2][0]],
+ [pos1[k1][1], pos2[k2][1]],
+ '-', lw=0.8, alpha=0.5,
+ color=nodes_color_part1[k1])
+ return pos1, pos2
+
+
+node_size = 40
+fontsize = 10
+seed_G2 = 0
+seed_G3 = 4
+
+pl.figure(1, figsize=(8, 2.5))
+pl.clf()
+pl.subplot(121)
+pl.axis('off')
+pl.axis
+pl.title(r'srGW$(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$' % (np.round(srgw_23, 3)), fontsize=fontsize)
+
+hbar2 = OT_23.sum(axis=0)
+pos1, pos2 = draw_transp_colored_srGW(
+ weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23,
+ shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+pl.subplot(122)
+pl.axis('off')
+hbar3 = OT_32.sum(axis=0)
+pl.title(r'srGW$(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$' % (np.round(srgw_32, 3)), fontsize=fontsize)
+pos1, pos2 = draw_transp_colored_srGW(
+ weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32,
+ pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0)
+pl.tight_layout()
+
+pl.show()
+
+#############################################################################
+#
+# Add node features
+# ---------------------------------------------
+
+# We add node features with given mean - by clusters
+# and inversely proportional to clusters' intra-connectivity
+
+F2 = np.zeros((N2, 1))
+for i, c in enumerate(part_G2):
+ F2[i, 0] = np.random.normal(loc=c, scale=0.01)
+
+F3 = np.zeros((N3, 1))
+for i, c in enumerate(part_G3):
+ F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01)
+
+#############################################################################
+#
+# Compute their semi-relaxed Fused Gromov-Wasserstein divergences
+# ---------------------------------------------
+
+alpha = 0.5
+# Compute pairwise euclidean distance between node features
+M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T)
+
+# 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference
+
+OT, log = fused_gromov_wasserstein(
+ M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True)
+fgw = log['fgw_dist']
+
+# 1) srFGW(C2, F2, h2, C3, F3)
+OT_23, log_23 = semirelaxed_fused_gromov_wasserstein(
+ M, C2, C3, h2, symmetric=True, alpha=0.5, log=True, G0=None)
+srfgw_23 = log_23['srfgw_dist']
+
+# 2) srFGW(C3, F3, h3, C2, F2)
+
+OT_32, log_32 = semirelaxed_fused_gromov_wasserstein(
+ M.T, C3, C2, h3, symmetric=None, alpha=alpha, log=True, G0=None)
+srfgw_32 = log_32['srfgw_dist']
+
+print('FGW(C2, F2, C3, F3) = ', fgw)
+print('srGW(C2, F2, h2, C3, F3) = ', srfgw_23)
+print('srGW(C3, F3, h3, C2, F2) = ', srfgw_32)
+
+#############################################################################
+#
+# Visualization of the semi-relaxed Fused Gromov-Wasserstein matchings
+# ---------------------------------------------
+#
+# We color nodes of the graph on the right - then project its node colors
+# based on the optimal transport plan from the srFGW matching
+# NB: colors refer to clusters - not to node features
+
+pl.figure(2, figsize=(8, 2.5))
+pl.clf()
+pl.subplot(121)
+pl.axis('off')
+pl.axis
+pl.title(r'srFGW$(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$' % (np.round(srfgw_23, 3)), fontsize=fontsize)
+
+hbar2 = OT_23.sum(axis=0)
+pos1, pos2 = draw_transp_colored_srGW(
+ weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23,
+ shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+pl.subplot(122)
+pl.axis('off')
+hbar3 = OT_32.sum(axis=0)
+pl.title(r'srFGW$(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$' % (np.round(srfgw_32, 3)), fontsize=fontsize)
+pos1, pos2 = draw_transp_colored_srGW(
+ weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32,
+ pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0)
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/others/plot_COOT.py b/examples/others/plot_COOT.py
new file mode 100644
index 0000000..98c1ce1
--- /dev/null
+++ b/examples/others/plot_COOT.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+r"""
+===================================================
+Row and column alignments with CO-Optimal Transport
+===================================================
+
+This example is designed to show how to use the CO-Optimal Transport [47]_ in POT.
+CO-Optimal Transport allows to calculate the distance between two **arbitrary-size**
+matrices, and to align their rows and columns. In this example, we consider two
+random matrices :math:`X_1` and :math:`X_2` defined by
+:math:`(X_1)_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi) + \sigma \mathcal N(0,1)`
+and :math:`(X_2)_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi) + \sigma \mathcal N(0,1)`.
+
+.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
+ `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
+ Advances in Neural Information Processing Systems, 33.
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+# License: MIT License
+
+from matplotlib.patches import ConnectionPatch
+import matplotlib.pylab as pl
+import numpy as np
+from ot.coot import co_optimal_transport as coot
+from ot.coot import co_optimal_transport2 as coot2
+
+# %%
+# Generating two random matrices
+
+n1 = 20
+n2 = 10
+d1 = 16
+d2 = 8
+sigma = 0.2
+
+X1 = (
+ np.cos(np.arange(n1) * np.pi / n1)[:, None] +
+ np.cos(np.arange(d1) * np.pi / d1)[None, :] +
+ sigma * np.random.randn(n1, d1)
+)
+X2 = (
+ np.cos(np.arange(n2) * np.pi / n2)[:, None] +
+ np.cos(np.arange(d2) * np.pi / d2)[None, :] +
+ sigma * np.random.randn(n2, d2)
+)
+
+# %%
+# Visualizing the matrices
+
+pl.figure(1, (8, 5))
+pl.subplot(1, 2, 1)
+pl.imshow(X1)
+pl.title('$X_1$')
+
+pl.subplot(1, 2, 2)
+pl.imshow(X2)
+pl.title("$X_2$")
+
+pl.tight_layout()
+
+# %%
+# Visualizing the alignments of rows and columns, and calculating the CO-Optimal Transport distance
+
+pi_sample, pi_feature, log = coot(X1, X2, log=True, verbose=True)
+coot_distance = coot2(X1, X2)
+print('CO-Optimal Transport distance = {:.5f}'.format(coot_distance))
+
+fig = pl.figure(4, (9, 7))
+pl.clf()
+
+ax1 = pl.subplot(2, 2, 3)
+pl.imshow(X1)
+pl.xlabel('$X_1$')
+
+ax2 = pl.subplot(2, 2, 2)
+ax2.yaxis.tick_right()
+pl.imshow(np.transpose(X2))
+pl.title("Transpose($X_2$)")
+ax2.xaxis.tick_top()
+
+for i in range(n1):
+ j = np.argmax(pi_sample[i, :])
+ xyA = (d1 - .5, i)
+ xyB = (j, d2 - .5)
+ con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
+ coordsB=ax2.transData, color="black")
+ fig.add_artist(con)
+
+for i in range(d1):
+ j = np.argmax(pi_feature[i, :])
+ xyA = (i, -.5)
+ xyB = (-.5, j)
+ con = ConnectionPatch(
+ xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
+ fig.add_artist(con)
diff --git a/examples/others/plot_learning_weights_with_COOT.py b/examples/others/plot_learning_weights_with_COOT.py
new file mode 100644
index 0000000..cb115c3
--- /dev/null
+++ b/examples/others/plot_learning_weights_with_COOT.py
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+r"""
+===============================================================
+Learning sample marginal distribution with CO-Optimal Transport
+===============================================================
+
+In this example, we illustrate how to estimate the sample marginal distribution which minimizes
+the CO-Optimal Transport distance [47]_ between two matrices. More precisely, given a source data
+:math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a target matrix :math:`Y` associated with a fixed
+histogram on features :math:`\mu_y^{(f)}`, we want to solve the following problem
+
+.. math::
+ \min_{\mu_y^{(s)} \in \Delta} \text{COOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right)
+
+where :math:`\Delta` is the probability simplex. This minimization is done with a
+simple projected gradient descent in PyTorch. We use the automatic backend of POT that
+allows us to compute the CO-Optimal Transport distance with :func:`ot.coot.co_optimal_transport2`
+with differentiable losses.
+
+.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
+ `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
+ Advances in Neural Information Processing Systems, 33.
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+# License: MIT License
+
+from matplotlib.patches import ConnectionPatch
+import torch
+import numpy as np
+
+import matplotlib.pyplot as pl
+import ot
+
+from ot.coot import co_optimal_transport as coot
+from ot.coot import co_optimal_transport2 as coot2
+
+
+# %%
+# Generate data
+# -------------
+# The source and clean target matrices are generated by
+# :math:`X_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi)` and
+# :math:`Y_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi)`.
+# The target matrix is then contaminated by adding 5 row outliers.
+# Intuitively, we expect that the estimated sample distribution should ignore these outliers,
+# i.e. their weights should be zero.
+
+np.random.seed(182)
+
+n1, d1 = 20, 16
+n2, d2 = 10, 8
+n = 15
+
+X = (
+ torch.cos(torch.arange(n1) * torch.pi / n1)[:, None] +
+ torch.cos(torch.arange(d1) * torch.pi / d1)[None, :]
+)
+
+# Generate clean target data mixed with outliers
+Y_noisy = torch.randn((n, d2)) * 10.0
+Y_noisy[:n2, :] = (
+ torch.cos(torch.arange(n2) * torch.pi / n2)[:, None] +
+ torch.cos(torch.arange(d2) * torch.pi / d2)[None, :]
+)
+Y = Y_noisy[:n2, :]
+
+X, Y_noisy, Y = X.double(), Y_noisy.double(), Y.double()
+
+fig, axes = pl.subplots(nrows=1, ncols=3, figsize=(12, 5))
+axes[0].imshow(X, vmin=-2, vmax=2)
+axes[0].set_title('$X$')
+
+axes[1].imshow(Y, vmin=-2, vmax=2)
+axes[1].set_title('Clean $Y$')
+
+axes[2].imshow(Y_noisy, vmin=-2, vmax=2)
+axes[2].set_title('Noisy $Y$')
+
+pl.tight_layout()
+
+# %%
+# Optimize the COOT distance with respect to the sample marginal distribution
+# ---------------------------------------------------------------------------
+
+losses = []
+lr = 1e-3
+niter = 1000
+
+b = torch.tensor(ot.unif(n), requires_grad=True)
+
+for i in range(niter):
+
+ loss = coot2(X, Y_noisy, wy_samp=b, log=False, verbose=False)
+ losses.append(float(loss))
+
+ loss.backward()
+
+ with torch.no_grad():
+ b -= lr * b.grad # gradient step
+ b[:] = ot.utils.proj_simplex(b) # projection on the simplex
+
+ b.grad.zero_()
+
+# Estimated sample marginal distribution and training loss curve
+pl.plot(losses[10:])
+pl.title('CO-Optimal Transport distance')
+
+print(f"Marginal distribution = {b.detach().numpy()}")
+
+# %%
+# Visualizing the row and column alignments with the estimated sample marginal distribution
+# -----------------------------------------------------------------------------------------
+#
+# Clearly, the learned marginal distribution completely and successfully ignores the 5 outliers.
+
+X, Y_noisy = X.numpy(), Y_noisy.numpy()
+b = b.detach().numpy()
+
+pi_sample, pi_feature = coot(X, Y_noisy, wy_samp=b, log=False, verbose=True)
+
+fig = pl.figure(4, (9, 7))
+pl.clf()
+
+ax1 = pl.subplot(2, 2, 3)
+pl.imshow(X, vmin=-2, vmax=2)
+pl.xlabel('$X$')
+
+ax2 = pl.subplot(2, 2, 2)
+ax2.yaxis.tick_right()
+pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2)
+pl.title("Transpose(Noisy $Y$)")
+ax2.xaxis.tick_top()
+
+for i in range(n1):
+ j = np.argmax(pi_sample[i, :])
+ xyA = (d1 - .5, i)
+ xyB = (j, d2 - .5)
+ con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
+ coordsB=ax2.transData, color="black")
+ fig.add_artist(con)
+
+for i in range(d1):
+ j = np.argmax(pi_feature[i, :])
+ xyA = (i, -.5)
+ xyB = (-.5, j)
+ con = ConnectionPatch(
+ xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
+ fig.add_artist(con)
diff --git a/examples/plot_compute_wasserstein_circle.py b/examples/plot_compute_wasserstein_circle.py
new file mode 100644
index 0000000..3ede96f
--- /dev/null
+++ b/examples/plot_compute_wasserstein_circle.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+"""
+=========================
+OT distance on the Circle
+=========================
+
+Shows how to compute the Wasserstein distance on the circle
+
+
+"""
+
+# Author: Clément Bonet <clement.bonet@univ-ubs.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+from scipy.special import iv
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+
+def pdf_von_Mises(theta, mu, kappa):
+ pdf = np.exp(kappa * np.cos(theta - mu)) / (2.0 * np.pi * iv(0, kappa))
+ return pdf
+
+
+t = np.linspace(0, 2 * np.pi, 1000, endpoint=False)
+
+mu1 = 1
+kappa1 = 20
+
+mu_targets = np.linspace(mu1, mu1 + 2 * np.pi, 10)
+
+
+pdf1 = pdf_von_Mises(t, mu1, kappa1)
+
+
+pl.figure(1)
+for k, mu in enumerate(mu_targets):
+ pdf_t = pdf_von_Mises(t, mu, kappa1)
+ if k == 0:
+ label = "Source distributions"
+ else:
+ label = None
+ pl.plot(t / (2 * np.pi), pdf_t, c='b', label=label)
+
+pl.plot(t / (2 * np.pi), pdf1, c="r", label="Target distribution")
+pl.legend()
+
+mu2 = 0
+kappa2 = kappa1
+
+x1 = np.random.vonmises(mu1, kappa1, size=(10,)) + np.pi
+x2 = np.random.vonmises(mu2, kappa2, size=(10,)) + np.pi
+
+angles = np.linspace(0, 2 * np.pi, 150)
+
+pl.figure(2)
+pl.plot(np.cos(angles), np.sin(angles), c="k")
+pl.xlim(-1.25, 1.25)
+pl.ylim(-1.25, 1.25)
+pl.scatter(np.cos(x1), np.sin(x1), c="b")
+pl.scatter(np.cos(x2), np.sin(x2), c="r")
+
+#########################################################################################
+# Compare the Euclidean Wasserstein distance with the Wasserstein distance on the circle
+# ---------------------------------------------------------------------------------------
+# This examples illustrates the periodicity of the Wasserstein distance on the circle.
+# We choose as target distribution a von Mises distribution with mean :math:`\mu_{\mathrm{target}}`
+# and :math:`\kappa=20`. Then, we compare the distances with samples obtained from a von Mises distribution
+# with parameters :math:`\mu_{\mathrm{source}}` and :math:`\kappa=20`.
+# The Wasserstein distance on the circle takes into account the periodicity
+# and attains its maximum in :math:`\mu_{\mathrm{target}}+1` (the antipodal point) contrary to the
+# Euclidean version.
+
+#%% Compute and plot distributions
+
+mu_targets = np.linspace(0, 2 * np.pi, 200)
+xs = np.random.vonmises(mu1 - np.pi, kappa1, size=(500,)) + np.pi
+
+n_try = 5
+
+xts = np.zeros((n_try, 200, 500))
+for i in range(n_try):
+ for k, mu in enumerate(mu_targets):
+ # np.random.vonmises deals with data on [-pi, pi[
+ xt = np.random.vonmises(mu - np.pi, kappa2, size=(500,)) + np.pi
+ xts[i, k] = xt
+
+# Put data on S^1=[0,1[
+xts2 = xts / (2 * np.pi)
+xs2 = np.concatenate([xs[None] for k in range(200)], axis=0) / (2 * np.pi)
+
+L_w2_circle = np.zeros((n_try, 200))
+L_w2 = np.zeros((n_try, 200))
+
+for i in range(n_try):
+ w2_circle = ot.wasserstein_circle(xs2.T, xts2[i].T, p=2)
+ w2 = ot.wasserstein_1d(xs2.T, xts2[i].T, p=2)
+
+ L_w2_circle[i] = w2_circle
+ L_w2[i] = w2
+
+m_w2_circle = np.mean(L_w2_circle, axis=0)
+std_w2_circle = np.std(L_w2_circle, axis=0)
+
+m_w2 = np.mean(L_w2, axis=0)
+std_w2 = np.std(L_w2, axis=0)
+
+pl.figure(1)
+pl.plot(mu_targets / (2 * np.pi), m_w2_circle, label="Wasserstein circle")
+pl.fill_between(mu_targets / (2 * np.pi), m_w2_circle - 2 * std_w2_circle, m_w2_circle + 2 * std_w2_circle, alpha=0.5)
+pl.plot(mu_targets / (2 * np.pi), m_w2, label="Euclidean Wasserstein")
+pl.fill_between(mu_targets / (2 * np.pi), m_w2 - 2 * std_w2, m_w2 + 2 * std_w2, alpha=0.5)
+pl.vlines(x=[mu1 / (2 * np.pi)], ymin=0, ymax=np.max(w2), linestyle="--", color="k", label=r"$\mu_{\mathrm{target}}$")
+pl.legend()
+pl.xlabel(r"$\mu_{\mathrm{source}}$")
+pl.show()
+
+
+########################################################################
+# Wasserstein distance between von Mises and uniform for different kappa
+# ----------------------------------------------------------------------
+# When :math:`\kappa=0`, the von Mises distribution is the uniform distribution on :math:`S^1`.
+
+#%% Compute Wasserstein between Von Mises and uniform
+
+kappas = np.logspace(-5, 2, 100)
+n_try = 20
+
+xts = np.zeros((n_try, 100, 500))
+for i in range(n_try):
+ for k, kappa in enumerate(kappas):
+ # np.random.vonmises deals with data on [-pi, pi[
+ xt = np.random.vonmises(0, kappa, size=(500,)) + np.pi
+ xts[i, k] = xt / (2 * np.pi)
+
+L_w2 = np.zeros((n_try, 100))
+for i in range(n_try):
+ L_w2[i] = ot.semidiscrete_wasserstein2_unif_circle(xts[i].T)
+
+m_w2 = np.mean(L_w2, axis=0)
+std_w2 = np.std(L_w2, axis=0)
+
+pl.figure(1)
+pl.plot(kappas, m_w2)
+pl.fill_between(kappas, m_w2 - std_w2, m_w2 + std_w2, alpha=0.5)
+pl.title(r"Evolution of $W_2^2(vM(0,\kappa), Unif(S^1))$")
+pl.xlabel(r"$\kappa$")
+pl.show()
+
+# %%
diff --git a/examples/sliced-wasserstein/plot_variance_ssw.py b/examples/sliced-wasserstein/plot_variance_ssw.py
new file mode 100644
index 0000000..83d458f
--- /dev/null
+++ b/examples/sliced-wasserstein/plot_variance_ssw.py
@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================
+Spherical Sliced Wasserstein on distributions in S^2
+====================================================
+
+This example illustrates the computation of the spherical sliced Wasserstein discrepancy as
+proposed in [46].
+
+[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). 'Spherical Sliced-Wasserstein". International Conference on Learning Representations.
+
+"""
+
+# Author: Clément Bonet <clement.bonet@univ-ubs.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import matplotlib.pylab as pl
+import numpy as np
+
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+# %% parameters and data generation
+
+n = 500 # nb samples
+
+xs = np.random.randn(n, 3)
+xt = np.random.randn(n, 3)
+
+xs = xs / np.sqrt(np.sum(xs**2, -1, keepdims=True))
+xt = xt / np.sqrt(np.sum(xt**2, -1, keepdims=True))
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+##############################################################################
+# Plot data
+# ---------
+
+# %% plot samples
+
+fig = pl.figure(figsize=(10, 10))
+ax = pl.axes(projection='3d')
+ax.grid(False)
+
+u, v = np.mgrid[0:2 * np.pi:30j, 0:np.pi:30j]
+x = np.cos(u) * np.sin(v)
+y = np.sin(u) * np.sin(v)
+z = np.cos(v)
+ax.plot_surface(x, y, z, color="gray", alpha=0.03)
+ax.plot_wireframe(x, y, z, linewidth=1, alpha=0.25, color="gray")
+
+ax.scatter(xs[:, 0], xs[:, 1], xs[:, 2], label="Source")
+ax.scatter(xt[:, 0], xt[:, 1], xt[:, 2], label="Target")
+
+fs = 10
+# Labels
+ax.set_xlabel('x', fontsize=fs)
+ax.set_ylabel('y', fontsize=fs)
+ax.set_zlabel('z', fontsize=fs)
+
+ax.view_init(20, 120)
+ax.set_xlim(-1.5, 1.5)
+ax.set_ylim(-1.5, 1.5)
+ax.set_zlim(-1.5, 1.5)
+
+# Ticks
+ax.set_xticks([-1, 0, 1])
+ax.set_yticks([-1, 0, 1])
+ax.set_zticks([-1, 0, 1])
+
+pl.legend(loc=0)
+pl.title("Source and Target distribution")
+
+###############################################################################
+# Spherical Sliced Wasserstein for different seeds and number of projections
+# --------------------------------------------------------------------------
+
+n_seed = 50
+n_projections_arr = np.logspace(0, 3, 25, dtype=int)
+res = np.empty((n_seed, 25))
+
+# %% Compute statistics
+for seed in range(n_seed):
+ for i, n_projections in enumerate(n_projections_arr):
+ res[seed, i] = ot.sliced_wasserstein_sphere(xs, xt, a, b, n_projections, seed=seed, p=1)
+
+res_mean = np.mean(res, axis=0)
+res_std = np.std(res, axis=0)
+
+###############################################################################
+# Plot Spherical Sliced Wasserstein
+# ---------------------------------
+
+pl.figure(2)
+pl.plot(n_projections_arr, res_mean, label=r"$SSW_1$")
+pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5)
+
+pl.legend()
+pl.xscale('log')
+
+pl.xlabel("Number of projections")
+pl.ylabel("Distance")
+pl.title('Spherical Sliced Wasserstein Distance with 95% confidence inverval')
+
+pl.show()
diff --git a/examples/unbalanced-partial/plot_UOT_barycenter_1D.py b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py
index 931798b..8d227c0 100644
--- a/examples/unbalanced-partial/plot_UOT_barycenter_1D.py
+++ b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py
@@ -127,7 +127,7 @@ for i, z in enumerate(zs):
ys = B_l2[:, i]
verts.append(list(zip(x, ys)))
-ax = pl.gcf().gca(projection='3d')
+ax = pl.gcf().add_subplot(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
poly.set_alpha(0.7)
@@ -149,7 +149,7 @@ for i, z in enumerate(zs):
ys = B_wass[:, i]
verts.append(list(zip(x, ys)))
-ax = pl.gcf().gca(projection='3d')
+ax = pl.gcf().add_subplot(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
poly.set_alpha(0.7)