Diffstat (limited to 'examples')
-rw-r--r--  examples/backends/plot_sliced_wass_grad_flow_pytorch.py  185
-rw-r--r--  examples/backends/plot_wass1d_torch.py                    152
-rw-r--r--  examples/sliced-wasserstein/plot_variance.py                2
3 files changed, 338 insertions(+), 1 deletion(-)
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
new file mode 100644
index 0000000..05b9952
--- /dev/null
+++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
@@ -0,0 +1,185 @@
+r"""
+=============================================================
+Sliced Wasserstein barycenter and gradient flow with PyTorch
+=============================================================
+
+In this example we use the PyTorch backend to optimize the sliced Wasserstein
+loss between two empirical distributions [31].
+
+In the first example we perform a
+gradient flow on the support of a distribution that minimizes the sliced
+Wasserstein distance, as proposed in [36].
+
+In the second example we optimize with a gradient descent the sliced
+Wasserstein barycenter between two distributions as in [31].
+
+[31] Bonneel, Nicolas, et al. "Sliced and Radon Wasserstein barycenters of
+measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45.
+
+[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
+(2019, May). Sliced-Wasserstein flows: Nonparametric generative modeling
+via optimal transport and diffusions. In International Conference on
+Machine Learning (pp. 4104-4113). PMLR.
+
+
+"""
+# Author: Rémi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+
+# %%
+# Loading the data
+
+
+import numpy as np
+import matplotlib.pylab as pl
+import torch
+import ot
+import matplotlib.animation as animation
+
+I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2]
+I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2]
+
+sz = I2.shape[0]
+XX, YY = np.meshgrid(np.arange(sz), np.arange(sz))
+
+x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0
+x2 = np.stack((XX[I2 == 0] + 60, -YY[I2 == 0] + 32), 1) * 1.0
+x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0
+
+pl.figure(1, (8, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+
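+# %%
+# What the sliced Wasserstein distance computes
+# ---------------------------------------------
+#
+# For intuition, here is a minimal hand-made estimate of the sliced Wasserstein
+# distance between the two point clouds: draw random unit directions, project
+# both clouds on each direction, compute the 1D Wasserstein distance of the
+# projections, and take the root of the average (``p=2``). The variable names
+# and the number of directions below are illustrative choices only; the
+# gradient flow itself relies on ``ot.sliced_wasserstein_distance``.
+
+n_proj = 20  # illustrative number of random directions
+rng = np.random.RandomState(42)
+thetas = rng.randn(n_proj, 2)
+thetas /= np.linalg.norm(thetas, axis=1, keepdims=True)  # unit directions
+
+# average of the squared 1D Wasserstein distances of the projections, then root
+sw_by_hand = np.sqrt(np.mean(
+    [ot.lp.wasserstein_1d(x1 @ th, x2 @ th, p=2) for th in thetas]))
+print('Hand-made sliced Wasserstein estimate:', sw_by_hand)
+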
+# %%
+# Sliced Wasserstein gradient flow with PyTorch
+# ---------------------------------------------
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# use pyTorch for our data
+x1_torch = torch.tensor(x1).to(device=device).requires_grad_(True)
+x2_torch = torch.tensor(x2).to(device=device)
+
+
+lr = 1e3
+nb_iter_max = 100
+
+x_all = np.zeros((nb_iter_max, x1.shape[0], 2))
+
+loss_iter = []
+
+# generator used for the random projections (fixed seed for reproducibility)
+gen = torch.Generator()
+gen.manual_seed(42)
+
+for i in range(nb_iter_max):
+
+    loss = ot.sliced_wasserstein_distance(x1_torch, x2_torch, n_projections=20, seed=gen)
+
+    loss_iter.append(loss.clone().detach().cpu().numpy())
+    loss.backward()
+
+    # perform a gradient descent step on the support (with a decreasing step size)
+    with torch.no_grad():
+        grad = x1_torch.grad
+        x1_torch -= grad * lr / (1 + i / 5e1)  # step
+        x1_torch.grad.zero_()
+        x_all[i, :, :] = x1_torch.clone().detach().cpu().numpy()
+
+xb = x1_torch.clone().detach().cpu().numpy()
+
+pl.figure(2, (8, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label=r'$\mu^{(0)}$')
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+pl.scatter(xb[:, 0], xb[:, 1], alpha=0.5, label=r'$\mu^{(100)}$')
+pl.title('Sliced Wasserstein gradient flow')
+pl.legend()
+ax = pl.axis()
+
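+# %%
+# The recorded loss values can be plotted to check that the sliced Wasserstein
+# distance between the point clouds decreases along the flow (up to the noise
+# of the random projections). The figure number below is an arbitrary choice.
+
+pl.figure(10, (8, 4))
+pl.plot(loss_iter)
+pl.xlabel('Iteration')
+pl.ylabel('Sliced Wasserstein distance')
+pl.title('Loss of the sliced Wasserstein gradient flow')
+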
+# %%
+# Animate trajectories of the gradient flow along iterations
+# -----------------------------------------------------------
+
+pl.figure(3, (8, 4))
+
+
+def _update_plot(i):
+    pl.clf()
+    pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label=r'$\mu^{(0)}$')
+    pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+    pl.scatter(x_all[i, :, 0], x_all[i, :, 1], alpha=0.5, label=r'$\mu^{(100)}$')
+    pl.title('Sliced Wasserstein gradient flow Iter. {}'.format(i))
+    pl.axis(ax)
+    return 1
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000)
+
+# %%
+# Compute the sliced Wasserstein barycenter
+# ------------------------------------------
+x1_torch = torch.tensor(x1).to(device=device)
+x3_torch = torch.tensor(x3).to(device=device)
+xbinit = np.random.randn(500, 2) * 10 + 16
+xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True)
+
+lr = 1e3
+nb_iter_max = 100
+
+x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2))
+
+loss_iter = []
+
+# generator used for the random projections (fixed seed for reproducibility)
+gen = torch.Generator()
+gen.manual_seed(42)
+
+alpha = 0.5  # interpolation weight between the two target distributions
+
+for i in range(nb_iter_max):
+
+    loss = alpha * ot.sliced_wasserstein_distance(xbary_torch, x3_torch, n_projections=50, seed=gen) \
+        + (1 - alpha) * ot.sliced_wasserstein_distance(xbary_torch, x1_torch, n_projections=50, seed=gen)
+
+    loss_iter.append(loss.clone().detach().cpu().numpy())
+    loss.backward()
+
+    # perform a gradient descent step on the barycenter support
+    with torch.no_grad():
+        grad = xbary_torch.grad
+        xbary_torch -= grad * lr  # / (1 + i / 5e1)  # step
+        xbary_torch.grad.zero_()
+        x_all[i, :, :] = xbary_torch.clone().detach().cpu().numpy()
+
+xb = xbary_torch.clone().detach().cpu().numpy()
+
+pl.figure(4, (8, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label=r'$\mu$')
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+pl.scatter(xb[:, 0] + 30, xb[:, 1], alpha=0.5, label='Barycenter')
+pl.title('Sliced Wasserstein barycenter')
+pl.legend()
+ax = pl.axis()
+
+
+# %%
+# Animate trajectories of the barycenter along gradient descent
+# ---------------------------------------------------------------
+
+pl.figure(5, (8, 4))
+
+
+def _update_plot(i):
+    pl.clf()
+    pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label=r'$\mu^{(0)}$')
+    pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+    pl.scatter(x_all[i, :, 0] + 30, x_all[i, :, 1], alpha=0.5, label='Barycenter')
+    pl.title('Sliced Wasserstein barycenter Iter. {}'.format(i))
+    pl.axis(ax)
+    return 1
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000)
diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py
new file mode 100644
index 0000000..0abdd6d
--- /dev/null
+++ b/examples/backends/plot_wass1d_torch.py
@@ -0,0 +1,152 @@
+r"""
+=================================
+Wasserstein 1D with PyTorch
+=================================
+
+In this small example, we consider the following minimization problem:
+
+.. math::
+   \mu^* = \mathop{\arg \min}_\mu W(\mu, \nu)
+
+where :math:`\nu` is a reference 1D measure. The problem is handled
+by a projected gradient descent method, where the gradient is computed
+by PyTorch automatic differentiation. The projection onto the simplex
+ensures that the iterates remain on the probability simplex.
+
+This example illustrates the use of both the `wasserstein_1d` function and
+the PyTorch backend within the POT framework.
+"""
+# Author: Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import matplotlib as mpl
+import torch
+
+from ot.lp import wasserstein_1d
+from ot.datasets import make_1D_gauss as gauss
+from ot.utils import proj_simplex
+
+red = np.array(mpl.colors.to_rgb('red'))
+blue = np.array(mpl.colors.to_rgb('blue'))
+
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# enforce sum to one on the support
+a = a / a.sum()
+b = b / b.sum()
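+
+# %%
+# The projected gradient descent below keeps its iterates on the probability
+# simplex with ``proj_simplex`` (Euclidean projection onto nonnegative vectors
+# summing to one). A quick illustration on an arbitrary vector (the values are
+# made up for this sketch):
+
+v_demo = proj_simplex(np.array([0.3, -0.2, 1.5]))
+print(v_demo, v_demo.sum())  # nonnegative entries that sum to 1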
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# use pyTorch for our data
+x_torch = torch.tensor(x).to(device=device)
+a_torch = torch.tensor(a).to(device=device).requires_grad_(True)
+b_torch = torch.tensor(b).to(device=device)
+
+lr = 1e-6
+nb_iter_max = 800
+
+loss_iter = []
+
+pl.figure(1, figsize=(8, 4))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+
+for i in range(nb_iter_max):
+    # Compute the 1D Wasserstein loss with the torch backend
+    loss = wasserstein_1d(x_torch, x_torch, a_torch, b_torch, p=2)
+    # record the corresponding loss value
+    loss_iter.append(loss.clone().detach().cpu().numpy())
+    loss.backward()
+
+    # perform a step of projected gradient descent
+    with torch.no_grad():
+        grad = a_torch.grad
+        a_torch -= grad * lr  # step
+        a_torch.grad.zero_()
+        a_torch.data = proj_simplex(a_torch)  # projection onto the simplex
+
+    # plot one curve every 10 iterations
+    if i % 10 == 0:
+        mix = float(i) / nb_iter_max
+        pl.plot(x, a_torch.clone().detach().cpu().numpy(), c=(1 - mix) * blue + mix * red)
+
+pl.legend()
+pl.title('Distribution along the iterations of the projected gradient descent')
+pl.show()
+
+pl.figure(2)
+pl.plot(range(nb_iter_max), loss_iter, lw=3)
+pl.title('Evolution of the loss along iterations', fontsize=16)
+pl.show()
+
+# %%
+# Wasserstein barycenter
+# ----------------------
+# In this example, we consider the following Wasserstein barycenter problem
+#
+# .. math::
+#    \eta^* = \mathop{\arg \min}_\eta \; (1-t) W(\mu, \eta) + t W(\eta, \nu)
+#
+# where :math:`\mu` and :math:`\nu` are reference 1D measures, and :math:`t`
+# is a parameter in :math:`[0, 1]`. The problem is handled by a projected
+# gradient descent method, where the gradient is computed by PyTorch automatic
+# differentiation. The projection onto the simplex ensures that the iterates
+# remain on the probability simplex.
+#
+# This example illustrates the use of both the `wasserstein_1d` function and
+# the PyTorch backend within the POT framework.
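+#
+# Before running the descent, note that for the quadratic cost used here
+# (``p=2``) the 1D barycenter has a closed form: its quantile function is the
+# weighted average of the quantile functions of :math:`\mu` and :math:`\nu`.
+# The few lines below build a rough histogram approximation of that reference
+# solution; the names ``t_ref``, ``q_levels`` and ``bary_ref`` are only used
+# for this sketch and the result is not needed by the gradient descent.
+
+t_ref = 0.5  # illustrative weight, matching the value of t used below
+q_levels = np.linspace(1e-3, 1 - 1e-3, 500)  # quantile levels
+# linear-interpolation approximations of the quantile functions of a and b
+quantiles_ref = (1 - t_ref) * np.interp(q_levels, np.cumsum(a), x) \
+    + t_ref * np.interp(q_levels, np.cumsum(b), x)
+# bin the averaged quantiles back onto the grid to get a reference histogram
+bary_ref, _ = np.histogram(quantiles_ref, bins=n, range=(x[0], x[-1]))
+bary_ref = bary_ref / bary_ref.sum()
+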
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# use pyTorch for our data
+x_torch = torch.tensor(x).to(device=device)
+a_torch = torch.tensor(a).to(device=device)
+b_torch = torch.tensor(b).to(device=device)
+bary_torch = torch.tensor((a + b).copy() / 2).to(device=device).requires_grad_(True)
+
+
+lr = 1e-6
+nb_iter_max = 2000
+
+loss_iter = []
+
+# interpolation parameter t between the two distributions (in [0, 1])
+t = 0.5
+
+for i in range(nb_iter_max):
+    # Compute the 1D Wasserstein losses with the torch backend
+    loss = (1 - t) * wasserstein_1d(x_torch, x_torch, a_torch.detach(), bary_torch, p=2) \
+        + t * wasserstein_1d(x_torch, x_torch, b_torch, bary_torch, p=2)
+    # record the corresponding loss value
+    loss_iter.append(loss.clone().detach().cpu().numpy())
+    loss.backward()
+
+    # perform a step of projected gradient descent
+    with torch.no_grad():
+        grad = bary_torch.grad
+        bary_torch -= grad * lr  # step
+        bary_torch.grad.zero_()
+        bary_torch.data = proj_simplex(bary_torch)  # projection onto the simplex
+
+pl.figure(3, figsize=(8, 4))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.plot(x, bary_torch.clone().detach().cpu().numpy(), c='green', label='W barycenter')
+pl.legend()
+pl.title('Wasserstein barycenter computed by gradient descent')
+pl.show()
+
+pl.figure(4)
+pl.plot(range(nb_iter_max), loss_iter, lw=3)
+pl.title('Evolution of the loss along iterations', fontsize=16)
+pl.show()
diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py
index 27df656..7d73907 100644
--- a/examples/sliced-wasserstein/plot_variance.py
+++ b/examples/sliced-wasserstein/plot_variance.py
@@ -63,7 +63,7 @@ res = np.empty((n_seed, 25))
 # %% Compute statistics
 for seed in range(n_seed):
     for i, n_projections in enumerate(n_projections_arr):
-        res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed)
+        res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed=seed)
 res_mean = np.mean(res, axis=0)
 res_std = np.std(res, axis=0)