Diffstat (limited to 'examples')
-rw-r--r--  examples/backends/plot_sliced_wass_grad_flow_pytorch.py  | 185
-rw-r--r--  examples/backends/plot_wass1d_torch.py                   | 152
-rw-r--r--  examples/sliced-wasserstein/plot_variance.py             |   2
3 files changed, 338 insertions, 1 deletions
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
new file mode 100644
index 0000000..05b9952
--- /dev/null
+++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
@@ -0,0 +1,185 @@
+r"""
+=============================================================
+Sliced Wasserstein barycenter and gradient flow with PyTorch
+=============================================================
+
+In this example we use the PyTorch backend to optimize the sliced Wasserstein
+loss between two empirical distributions [31].
+
+In the first example we perform a
+gradient flow on the support of a distribution that minimizes the sliced
+Wasserstein distance, as proposed in [36].
+
+In the second example we optimize the sliced Wasserstein barycenter
+between two distributions by gradient descent, as in [31].
+
+[31] Bonneel, Nicolas, et al. "Sliced and Radon Wasserstein barycenters of
+measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+
+[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
+(2019, May). Sliced-Wasserstein flows: Nonparametric generative modeling
+via optimal transport and diffusions. In International Conference on
+Machine Learning (pp. 4104-4113). PMLR.
+
+
+"""
+# Author: Rémi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+
+# %%
+# Loading the data
+
+
+import numpy as np
+import matplotlib.pylab as pl
+import torch
+import ot
+import matplotlib.animation as animation
+
+I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2]
+I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2]
+
+sz = I2.shape[0]
+XX, YY = np.meshgrid(np.arange(sz), np.arange(sz))
+
+x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0
+x2 = np.stack((XX[I2 == 0] + 60, -YY[I2 == 0] + 32), 1) * 1.0
+x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0
+
+pl.figure(1, (8, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+
+# %%
+# Sliced Wasserstein gradient flow with PyTorch
+# ---------------------------------------------
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# use PyTorch for our data
+x1_torch = torch.tensor(x1).to(device=device).requires_grad_(True)
+x2_torch = torch.tensor(x2).to(device=device)
+
+
+lr = 1e3
+nb_iter_max = 100
+
+x_all = np.zeros((nb_iter_max, x1.shape[0], 2))
+
+loss_iter = []
+
+# generator for random permutations
+gen = torch.Generator()
+gen.manual_seed(42)
+
+for i in range(nb_iter_max):
+
+    loss = ot.sliced_wasserstein_distance(x1_torch, x2_torch, n_projections=20, seed=gen)
+
+    loss_iter.append(loss.clone().detach().cpu().numpy())
+    loss.backward()
+
+    # performs a step of projected gradient descent
+    with torch.no_grad():
+        grad = x1_torch.grad
+        x1_torch -= grad * lr / (1 + i / 5e1)  # step
+        x1_torch.grad.zero_()
+        x_all[i, :, :] = x1_torch.clone().detach().cpu().numpy()
+
+xb = x1_torch.clone().detach().cpu().numpy()
+
+pl.figure(2, (8, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label=r'$\mu^{(0)}$')
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+pl.scatter(xb[:, 0], xb[:, 1], alpha=0.5, label=r'$\mu^{(100)}$')
+pl.title('Sliced Wasserstein gradient flow')
+pl.legend()
+ax = pl.axis()
+
+# %%
+# Animate trajectories of the gradient flow along iterations
+# -----------------------------------------------------------
+
+pl.figure(3, (8, 4))
+
+
+def _update_plot(i):
+    pl.clf()
+    pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label=r'$\mu^{(0)}$')
+    pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+    pl.scatter(x_all[i, :, 0], x_all[i, :, 1], alpha=0.5, label=r'$\mu^{(100)}$')
+    pl.title('Sliced Wasserstein gradient flow Iter. {}'.format(i))
+    pl.axis(ax)
+    return 1
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000)
+
+# %%
+# Compute the Sliced Wasserstein Barycenter
+# -----------------------------------------
+x1_torch = torch.tensor(x1).to(device=device)
+x3_torch = torch.tensor(x3).to(device=device)
+xbinit = np.random.randn(500, 2) * 10 + 16
+xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True)
+
+lr = 1e3
+nb_iter_max = 100
+
+x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2))
+
+loss_iter = []
+
+# generator for random permutations
+gen = torch.Generator()
+gen.manual_seed(42)
+
+alpha = 0.5
+
+for i in range(nb_iter_max):
+
+    loss = alpha * ot.sliced_wasserstein_distance(xbary_torch, x3_torch, n_projections=50, seed=gen) \
+        + (1 - alpha) * ot.sliced_wasserstein_distance(xbary_torch, x1_torch, n_projections=50, seed=gen)
+
+    loss_iter.append(loss.clone().detach().cpu().numpy())
+    loss.backward()
+
+    # performs a step of projected gradient descent
+    with torch.no_grad():
+        grad = xbary_torch.grad
+        xbary_torch -= grad * lr  # / (1 + i / 5e1)  # step
+        xbary_torch.grad.zero_()
+        x_all[i, :, :] = xbary_torch.clone().detach().cpu().numpy()
+
+xb = xbary_torch.clone().detach().cpu().numpy()
+
+pl.figure(4, (8, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label=r'$\mu$')
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+pl.scatter(xb[:, 0] + 30, xb[:, 1], alpha=0.5, label='Barycenter')
+pl.title('Sliced Wasserstein barycenter')
+pl.legend()
+ax = pl.axis()
+
+
+# %%
+# Animate trajectories of the barycenter along gradient descent
+# --------------------------------------------------------------
+
+pl.figure(5, (8, 4))
+
+
+def _update_plot(i):
+    pl.clf()
+    pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label=r'$\mu^{(0)}$')
+    pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+    pl.scatter(x_all[i, :, 0] + 30, x_all[i, :, 1], alpha=0.5, label=r'$\mu^{(100)}$')
+    pl.title('Sliced Wasserstein barycenter Iter. {}'.format(i))
+    pl.axis(ax)
+    return 1
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000)
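
Note (not part of the diff): the quantity minimized in plot_sliced_wass_grad_flow_pytorch.py above is the Monte Carlo approximation, over ``n_projections`` random directions, of the sliced Wasserstein distance computed by ``ot.sliced_wasserstein_distance``. In the notation of [31] it reads

.. math::
   \mathcal{SW}_2^2(\mu, \nu) = \int_{\mathbb{S}^{d-1}} W_2^2(\theta_{\#}\mu, \theta_{\#}\nu) \, d\sigma(\theta),

where :math:`\theta_{\#}\mu` is the pushforward of :math:`\mu` by the projection onto direction :math:`\theta` and :math:`\sigma` is the uniform distribution on the unit sphere; the example replaces the integral by an empirical mean over the sampled projections.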
diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py
new file mode 100644
index 0000000..0abdd6d
--- /dev/null
+++ b/examples/backends/plot_wass1d_torch.py
@@ -0,0 +1,152 @@
+r"""
+=================================
+Wasserstein 1D with PyTorch
+=================================
+
+In this small example, we consider the following minimization problem:
+
+.. math::
+   \mu^* = \min_\mu W(\mu,\nu)
+
+where :math:`\nu` is a reference 1D measure. The problem is handled
+by a projected gradient descent method, where the gradient is computed
+by PyTorch automatic differentiation. The projection on the simplex
+ensures that the iterates remain on the probability simplex.
+
+This example illustrates both the `wasserstein_1d` function and backend use within
+the POT framework.
+""" +# Author: Nicolas Courty <ncourty@irisa.fr> +# Rémi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import matplotlib as mpl +import torch + +from ot.lp import wasserstein_1d +from ot.datasets import make_1D_gauss as gauss +from ot.utils import proj_simplex + +red = np.array(mpl.colors.to_rgb('red')) +blue = np.array(mpl.colors.to_rgb('blue')) + + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# enforce sum to one on the support +a = a / a.sum() +b = b / b.sum() + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# use pyTorch for our data +x_torch = torch.tensor(x).to(device=device) +a_torch = torch.tensor(a).to(device=device).requires_grad_(True) +b_torch = torch.tensor(b).to(device=device) + +lr = 1e-6 +nb_iter_max = 800 + +loss_iter = [] + +pl.figure(1, figsize=(8, 4)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') + +for i in range(nb_iter_max): + # Compute the Wasserstein 1D with torch backend + loss = wasserstein_1d(x_torch, x_torch, a_torch, b_torch, p=2) + # record the corresponding loss value + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = a_torch.grad + a_torch -= a_torch.grad * lr # step + a_torch.grad.zero_() + a_torch.data = proj_simplex(a_torch) # projection onto the simplex + + # plot one curve every 10 iterations + if i % 10 == 0: + mix = float(i) / nb_iter_max + pl.plot(x, a_torch.clone().detach().cpu().numpy(), c=(1 - mix) * blue + mix * red) + +pl.legend() +pl.title('Distribution along the iterations of the projected gradient descent') +pl.show() + +pl.figure(2) +pl.plot(range(nb_iter_max), loss_iter, lw=3) +pl.title('Evolution of the loss along iterations', fontsize=16) +pl.show() + +# %% +# Wasserstein barycenter +# --------- +# In this example, we consider the following Wasserstein barycenter problem +# $$ \\eta^* = \\min_\\eta\;\;\; (1-t)W(\\mu,\\eta) + tW(\\eta,\\nu)$$ +# where :math:`\\mu` and :math:`\\nu` are reference 1D measures, and :math:`t` +# is a parameter :math:`\in [0,1]`. The problem is handled by a project gradient +# descent method, where the gradient is computed by pyTorch automatic differentiation. +# The projection on the simplex ensures that the iterate will remain on the +# probability simplex. +# +# This example illustrates both `wasserstein_1d` function and backend use within the +# POT framework. 
+ + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# use pyTorch for our data +x_torch = torch.tensor(x).to(device=device) +a_torch = torch.tensor(a).to(device=device) +b_torch = torch.tensor(b).to(device=device) +bary_torch = torch.tensor((a + b).copy() / 2).to(device=device).requires_grad_(True) + + +lr = 1e-6 +nb_iter_max = 2000 + +loss_iter = [] + +# instant of the interpolation +t = 0.5 + +for i in range(nb_iter_max): + # Compute the Wasserstein 1D with torch backend + loss = (1 - t) * wasserstein_1d(x_torch, x_torch, a_torch.detach(), bary_torch, p=2) + t * wasserstein_1d(x_torch, x_torch, b_torch, bary_torch, p=2) + # record the corresponding loss value + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = bary_torch.grad + bary_torch -= bary_torch.grad * lr # step + bary_torch.grad.zero_() + bary_torch.data = proj_simplex(bary_torch) # projection onto the simplex + +pl.figure(3, figsize=(8, 4)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.plot(x, bary_torch.clone().detach().cpu().numpy(), c='green', label='W barycenter') +pl.legend() +pl.title('Wasserstein barycenter computed by gradient descent') +pl.show() + +pl.figure(4) +pl.plot(range(nb_iter_max), loss_iter, lw=3) +pl.title('Evolution of the loss along iterations', fontsize=16) +pl.show() diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py index 27df656..7d73907 100644 --- a/examples/sliced-wasserstein/plot_variance.py +++ b/examples/sliced-wasserstein/plot_variance.py @@ -63,7 +63,7 @@ res = np.empty((n_seed, 25)) # %% Compute statistics for seed in range(n_seed): for i, n_projections in enumerate(n_projections_arr): - res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed) + res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed=seed) res_mean = np.mean(res, axis=0) res_std = np.std(res, axis=0) |