summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorClément Bonet <32179275+clbonet@users.noreply.github.com>2023-02-23 08:31:01 +0100
committerGitHub <noreply@github.com>2023-02-23 08:31:01 +0100
commit80e3c23bc968f866fd20344ddc443a3c7fcb3b0d (patch)
treee4c2e938896243842e290d8fcf78879a8f6960bf /examples
parent97feeb32b6c069d7bb44cd995531c2b820d59771 (diff)
[WIP] Wasserstein distance on the circle and Spherical Sliced-Wasserstein (#434)
* W circle + SSW * Tests + Example SSW_1 * Example Wasserstein Circle + Tests * Wasserstein on the circle wrt Unif * Example SSW unif * pep8 * np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests * np qr * rm test python 3.11 * update names, tests, backend transpose * Comment error batchs * semidiscrete_wasserstein2_unif_circle example * torch permute method instead of torch.permute for previous versions * update comments and doc * doc wasserstein circle model as [0,1[ * Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn
Diffstat (limited to 'examples')
-rw-r--r--examples/backends/plot_ssw_unif_torch.py153
-rw-r--r--examples/plot_compute_wasserstein_circle.py161
-rw-r--r--examples/sliced-wasserstein/plot_variance_ssw.py111
3 files changed, 425 insertions, 0 deletions
diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py
new file mode 100644
index 0000000..d1de5a9
--- /dev/null
+++ b/examples/backends/plot_ssw_unif_torch.py
@@ -0,0 +1,153 @@
+# -*- coding: utf-8 -*-
+r"""
+================================================
+Spherical Sliced-Wasserstein Embedding on Sphere
+================================================
+
+Here, we aim at transforming samples into a uniform
+distribution on the sphere by minimizing SSW:
+
+.. math::
+ \min_{x} SSW_2(\nu, \frac{1}{n}\sum_{i=1}^n \delta_{x_i})
+
+where :math:`\nu=\mathrm{Unif}(S^1)`.
+
+"""
+
+# Author: Clément Bonet <clement.bonet@univ-ubs.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+import numpy as np
+import matplotlib.pyplot as pl
+import matplotlib.animation as animation
+import torch
+import torch.nn.functional as F
+
+import ot
+
+
+# %%
+# Data generation
+# ---------------
+
+torch.manual_seed(1)
+
+N = 1000
+x0 = torch.rand(N, 3)
+x0 = F.normalize(x0, dim=-1)
+
+
+# %%
+# Plot data
+# ---------
+
+def plot_sphere(ax):
+ xlist = np.linspace(-1.0, 1.0, 50)
+ ylist = np.linspace(-1.0, 1.0, 50)
+ r = np.linspace(1.0, 1.0, 50)
+ X, Y = np.meshgrid(xlist, ylist)
+
+ Z = np.sqrt(r**2 - X**2 - Y**2)
+
+ ax.plot_wireframe(X, Y, Z, color="gray", alpha=.3)
+ ax.plot_wireframe(X, Y, -Z, color="gray", alpha=.3) # Now plot the bottom half
+
+
+# plot the distributions
+pl.figure(1)
+ax = pl.axes(projection='3d')
+plot_sphere(ax)
+ax.scatter(x0[:, 0], x0[:, 1], x0[:, 2], label='Data samples', alpha=0.5)
+ax.set_title('Data distribution')
+ax.legend()
+
+
+# %%
+# Gradient descent
+# ----------------
+
+x = x0.clone()
+x.requires_grad_(True)
+
+n_iter = 500
+lr = 100
+
+losses = []
+xvisu = torch.zeros(n_iter, N, 3)
+
+for i in range(n_iter):
+ sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500)
+ grad_x = torch.autograd.grad(sw, x)[0]
+
+ x = x - lr * grad_x
+ x = F.normalize(x, p=2, dim=1)
+
+ losses.append(sw.item())
+ xvisu[i, :, :] = x.detach().clone()
+
+ if i % 100 == 0:
+ print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+pl.figure(1)
+pl.semilogy(losses)
+pl.grid()
+pl.title('SSW')
+pl.xlabel("Iterations")
+
+
+# %%
+# Plot trajectories of generated samples along iterations
+# -------------------------------------------------------
+
+ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499]
+
+fig = pl.figure(3, (10, 10))
+for i in range(9):
+ # pl.subplot(3, 3, i + 1)
+ # ax = pl.axes(projection='3d')
+ ax = fig.add_subplot(3, 3, i + 1, projection='3d')
+ plot_sphere(ax)
+ ax.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], xvisu[ivisu[i], :, 2], label='Data samples', alpha=0.5)
+ ax.set_title('Iter. {}'.format(ivisu[i]))
+ #ax.axis("off")
+ if i == 0:
+ ax.legend()
+
+
+# %%
+# Animate trajectories of generated samples along iteration
+# -------------------------------------------------------
+
+pl.figure(4, (8, 8))
+
+
+def _update_plot(i):
+ i = 3 * i
+ pl.clf()
+ ax = pl.axes(projection='3d')
+ plot_sphere(ax)
+ ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples$', alpha=0.5)
+ ax.axis("off")
+ ax.set_xlim((-1.5, 1.5))
+ ax.set_ylim((-1.5, 1.5))
+ ax.set_title('Iter. {}'.format(i))
+ return 1
+
+
+print(xvisu.shape)
+
+i = 0
+ax = pl.axes(projection='3d')
+plot_sphere(ax)
+ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples from $G\#\mu_n$', alpha=0.5)
+ax.axis("off")
+ax.set_xlim((-1.5, 1.5))
+ax.set_ylim((-1.5, 1.5))
+ax.set_title('Iter. {}'.format(ivisu[i]))
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000)
+# %%
diff --git a/examples/plot_compute_wasserstein_circle.py b/examples/plot_compute_wasserstein_circle.py
new file mode 100644
index 0000000..3ede96f
--- /dev/null
+++ b/examples/plot_compute_wasserstein_circle.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+"""
+=========================
+OT distance on the Circle
+=========================
+
+Shows how to compute the Wasserstein distance on the circle
+
+
+"""
+
+# Author: Clément Bonet <clement.bonet@univ-ubs.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+from scipy.special import iv
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+
+def pdf_von_Mises(theta, mu, kappa):
+ pdf = np.exp(kappa * np.cos(theta - mu)) / (2.0 * np.pi * iv(0, kappa))
+ return pdf
+
+
+t = np.linspace(0, 2 * np.pi, 1000, endpoint=False)
+
+mu1 = 1
+kappa1 = 20
+
+mu_targets = np.linspace(mu1, mu1 + 2 * np.pi, 10)
+
+
+pdf1 = pdf_von_Mises(t, mu1, kappa1)
+
+
+pl.figure(1)
+for k, mu in enumerate(mu_targets):
+ pdf_t = pdf_von_Mises(t, mu, kappa1)
+ if k == 0:
+ label = "Source distributions"
+ else:
+ label = None
+ pl.plot(t / (2 * np.pi), pdf_t, c='b', label=label)
+
+pl.plot(t / (2 * np.pi), pdf1, c="r", label="Target distribution")
+pl.legend()
+
+mu2 = 0
+kappa2 = kappa1
+
+x1 = np.random.vonmises(mu1, kappa1, size=(10,)) + np.pi
+x2 = np.random.vonmises(mu2, kappa2, size=(10,)) + np.pi
+
+angles = np.linspace(0, 2 * np.pi, 150)
+
+pl.figure(2)
+pl.plot(np.cos(angles), np.sin(angles), c="k")
+pl.xlim(-1.25, 1.25)
+pl.ylim(-1.25, 1.25)
+pl.scatter(np.cos(x1), np.sin(x1), c="b")
+pl.scatter(np.cos(x2), np.sin(x2), c="r")
+
+#########################################################################################
+# Compare the Euclidean Wasserstein distance with the Wasserstein distance on the circle
+# ---------------------------------------------------------------------------------------
+# This examples illustrates the periodicity of the Wasserstein distance on the circle.
+# We choose as target distribution a von Mises distribution with mean :math:`\mu_{\mathrm{target}}`
+# and :math:`\kappa=20`. Then, we compare the distances with samples obtained from a von Mises distribution
+# with parameters :math:`\mu_{\mathrm{source}}` and :math:`\kappa=20`.
+# The Wasserstein distance on the circle takes into account the periodicity
+# and attains its maximum in :math:`\mu_{\mathrm{target}}+1` (the antipodal point) contrary to the
+# Euclidean version.
+
+#%% Compute and plot distributions
+
+mu_targets = np.linspace(0, 2 * np.pi, 200)
+xs = np.random.vonmises(mu1 - np.pi, kappa1, size=(500,)) + np.pi
+
+n_try = 5
+
+xts = np.zeros((n_try, 200, 500))
+for i in range(n_try):
+ for k, mu in enumerate(mu_targets):
+ # np.random.vonmises deals with data on [-pi, pi[
+ xt = np.random.vonmises(mu - np.pi, kappa2, size=(500,)) + np.pi
+ xts[i, k] = xt
+
+# Put data on S^1=[0,1[
+xts2 = xts / (2 * np.pi)
+xs2 = np.concatenate([xs[None] for k in range(200)], axis=0) / (2 * np.pi)
+
+L_w2_circle = np.zeros((n_try, 200))
+L_w2 = np.zeros((n_try, 200))
+
+for i in range(n_try):
+ w2_circle = ot.wasserstein_circle(xs2.T, xts2[i].T, p=2)
+ w2 = ot.wasserstein_1d(xs2.T, xts2[i].T, p=2)
+
+ L_w2_circle[i] = w2_circle
+ L_w2[i] = w2
+
+m_w2_circle = np.mean(L_w2_circle, axis=0)
+std_w2_circle = np.std(L_w2_circle, axis=0)
+
+m_w2 = np.mean(L_w2, axis=0)
+std_w2 = np.std(L_w2, axis=0)
+
+pl.figure(1)
+pl.plot(mu_targets / (2 * np.pi), m_w2_circle, label="Wasserstein circle")
+pl.fill_between(mu_targets / (2 * np.pi), m_w2_circle - 2 * std_w2_circle, m_w2_circle + 2 * std_w2_circle, alpha=0.5)
+pl.plot(mu_targets / (2 * np.pi), m_w2, label="Euclidean Wasserstein")
+pl.fill_between(mu_targets / (2 * np.pi), m_w2 - 2 * std_w2, m_w2 + 2 * std_w2, alpha=0.5)
+pl.vlines(x=[mu1 / (2 * np.pi)], ymin=0, ymax=np.max(w2), linestyle="--", color="k", label=r"$\mu_{\mathrm{target}}$")
+pl.legend()
+pl.xlabel(r"$\mu_{\mathrm{source}}$")
+pl.show()
+
+
+########################################################################
+# Wasserstein distance between von Mises and uniform for different kappa
+# ----------------------------------------------------------------------
+# When :math:`\kappa=0`, the von Mises distribution is the uniform distribution on :math:`S^1`.
+
+#%% Compute Wasserstein between Von Mises and uniform
+
+kappas = np.logspace(-5, 2, 100)
+n_try = 20
+
+xts = np.zeros((n_try, 100, 500))
+for i in range(n_try):
+ for k, kappa in enumerate(kappas):
+ # np.random.vonmises deals with data on [-pi, pi[
+ xt = np.random.vonmises(0, kappa, size=(500,)) + np.pi
+ xts[i, k] = xt / (2 * np.pi)
+
+L_w2 = np.zeros((n_try, 100))
+for i in range(n_try):
+ L_w2[i] = ot.semidiscrete_wasserstein2_unif_circle(xts[i].T)
+
+m_w2 = np.mean(L_w2, axis=0)
+std_w2 = np.std(L_w2, axis=0)
+
+pl.figure(1)
+pl.plot(kappas, m_w2)
+pl.fill_between(kappas, m_w2 - std_w2, m_w2 + std_w2, alpha=0.5)
+pl.title(r"Evolution of $W_2^2(vM(0,\kappa), Unif(S^1))$")
+pl.xlabel(r"$\kappa$")
+pl.show()
+
+# %%
diff --git a/examples/sliced-wasserstein/plot_variance_ssw.py b/examples/sliced-wasserstein/plot_variance_ssw.py
new file mode 100644
index 0000000..83d458f
--- /dev/null
+++ b/examples/sliced-wasserstein/plot_variance_ssw.py
@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================
+Spherical Sliced Wasserstein on distributions in S^2
+====================================================
+
+This example illustrates the computation of the spherical sliced Wasserstein discrepancy as
+proposed in [46].
+
+[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). 'Spherical Sliced-Wasserstein". International Conference on Learning Representations.
+
+"""
+
+# Author: Clément Bonet <clement.bonet@univ-ubs.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import matplotlib.pylab as pl
+import numpy as np
+
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+# %% parameters and data generation
+
+n = 500 # nb samples
+
+xs = np.random.randn(n, 3)
+xt = np.random.randn(n, 3)
+
+xs = xs / np.sqrt(np.sum(xs**2, -1, keepdims=True))
+xt = xt / np.sqrt(np.sum(xt**2, -1, keepdims=True))
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+##############################################################################
+# Plot data
+# ---------
+
+# %% plot samples
+
+fig = pl.figure(figsize=(10, 10))
+ax = pl.axes(projection='3d')
+ax.grid(False)
+
+u, v = np.mgrid[0:2 * np.pi:30j, 0:np.pi:30j]
+x = np.cos(u) * np.sin(v)
+y = np.sin(u) * np.sin(v)
+z = np.cos(v)
+ax.plot_surface(x, y, z, color="gray", alpha=0.03)
+ax.plot_wireframe(x, y, z, linewidth=1, alpha=0.25, color="gray")
+
+ax.scatter(xs[:, 0], xs[:, 1], xs[:, 2], label="Source")
+ax.scatter(xt[:, 0], xt[:, 1], xt[:, 2], label="Target")
+
+fs = 10
+# Labels
+ax.set_xlabel('x', fontsize=fs)
+ax.set_ylabel('y', fontsize=fs)
+ax.set_zlabel('z', fontsize=fs)
+
+ax.view_init(20, 120)
+ax.set_xlim(-1.5, 1.5)
+ax.set_ylim(-1.5, 1.5)
+ax.set_zlim(-1.5, 1.5)
+
+# Ticks
+ax.set_xticks([-1, 0, 1])
+ax.set_yticks([-1, 0, 1])
+ax.set_zticks([-1, 0, 1])
+
+pl.legend(loc=0)
+pl.title("Source and Target distribution")
+
+###############################################################################
+# Spherical Sliced Wasserstein for different seeds and number of projections
+# --------------------------------------------------------------------------
+
+n_seed = 50
+n_projections_arr = np.logspace(0, 3, 25, dtype=int)
+res = np.empty((n_seed, 25))
+
+# %% Compute statistics
+for seed in range(n_seed):
+ for i, n_projections in enumerate(n_projections_arr):
+ res[seed, i] = ot.sliced_wasserstein_sphere(xs, xt, a, b, n_projections, seed=seed, p=1)
+
+res_mean = np.mean(res, axis=0)
+res_std = np.std(res, axis=0)
+
+###############################################################################
+# Plot Spherical Sliced Wasserstein
+# ---------------------------------
+
+pl.figure(2)
+pl.plot(n_projections_arr, res_mean, label=r"$SSW_1$")
+pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5)
+
+pl.legend()
+pl.xscale('log')
+
+pl.xlabel("Number of projections")
+pl.ylabel("Distance")
+pl.title('Spherical Sliced Wasserstein Distance with 95% confidence inverval')
+
+pl.show()