From 0afd84d744a472903d427e3c7ae32e55fdd7b9a7 Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Mon, 4 Apr 2022 10:23:04 +0200
Subject: [WIP] Add backend dual loss and plan computation for stochastic
 optimization of regularized OT (#360)

* add losses and plan computations and example for dual optimization
* pep8
* add nice example
* update awesome example stochastic dual
* add all tests
* pep8 + speedup example
* add release info
---
 RELEASES.md                                       |   2 +
 examples/backends/plot_dual_ot_pytorch.py         | 168 ++++++++++++++
 .../backends/plot_stoch_continuous_ot_pytorch.py  | 189 ++++++++++++++++
 ot/stochastic.py                                  | 242 ++++++++++++++++++++-
 test/test_stochastic.py                           | 115 +++++++++-
 5 files changed, 713 insertions(+), 3 deletions(-)
 create mode 100644 examples/backends/plot_dual_ot_pytorch.py
 create mode 100644 examples/backends/plot_stoch_continuous_ot_pytorch.py

diff --git a/RELEASES.md b/RELEASES.md
index c2bd0d1..45336f7 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -5,6 +5,8 @@
 
 #### New features
 
+- Add stochastic loss and OT plan computation for regularized OT and
+  backend examples (PR #360).
 - Implementation of factored OT with emd and sinkhorn (PR #358).
 - A brand new logo for POT (PR #357)
 - Better list of related examples in quick start guide with `minigallery` (PR #334).
diff --git a/examples/backends/plot_dual_ot_pytorch.py b/examples/backends/plot_dual_ot_pytorch.py
new file mode 100644
index 0000000..d3f7a66
--- /dev/null
+++ b/examples/backends/plot_dual_ot_pytorch.py
@@ -0,0 +1,168 @@
+# -*- coding: utf-8 -*-
+r"""
+======================================================================
+Dual OT solvers for entropic and quadratic regularized OT with PyTorch
+======================================================================
+
+
+"""
+
+# Author: Remi Flamary
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+import numpy as np
+import matplotlib.pyplot as pl
+import torch
+import ot
+import ot.plot
+
+# %%
+# Data generation
+# ---------------
+
+torch.manual_seed(1)
+
+n_source_samples = 100
+n_target_samples = 100
+theta = 2 * np.pi / 20
+noise_level = 0.1
+
+Xs, ys = ot.datasets.make_data_classif(
+    'gaussrot', n_source_samples, nz=noise_level)
+Xt, yt = ot.datasets.make_data_classif(
+    'gaussrot', n_target_samples, theta=theta, nz=noise_level)
+
+# one of the target modes changes its variance (no linear mapping)
+Xt[yt == 2] *= 3
+Xt = Xt + 4
+
+
+# %%
+# Plot data
+# ---------
+
+pl.figure(1, (10, 5))
+pl.clf()
+pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples')
+pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+# %%
+# Convert data to torch tensors
+# -----------------------------
+
+xs = torch.tensor(Xs)
+xt = torch.tensor(Xt)
+
+# %%
+# Estimating dual variables for entropic OT
+# -----------------------------------------
+
+u = torch.randn(n_source_samples, requires_grad=True)
+v = torch.randn(n_target_samples, requires_grad=True)
+
+reg = 0.5
+
+optimizer = torch.optim.Adam([u, v], lr=1)
+
+# number of iterations
+n_iter = 200
+
+
+losses = []
+
+for i in range(n_iter):
+
+    # compute the loss on the full dataset (no minibatch sampling here)
+
+    # minus because we maximize the dual loss
+    loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=reg)
+    losses.append(float(loss.detach()))
+
+    if i % 10 == 0:
+        print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+    loss.backward()
+    optimizer.step()
+    optimizer.zero_grad()
+
+
+pl.figure(2)
+pl.plot(losses)
+pl.grid()
+pl.title('Dual objective (negative)')
+pl.xlabel("Iterations")
+
+Ge = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg)
+
+# %%
+# Plot the estimated entropic OT plan
+# -----------------------------------
+
+pl.figure(3, (10, 5))
+pl.clf()
+ot.plot.plot2D_samples_mat(Xs, Xt, Ge.detach().numpy(), alpha=0.1)
+pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2)
+pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2)
+pl.legend(loc=0)
+pl.title('OT plan with entropic regularization')
+
+
+# %%
+# Estimating dual variables for quadratic OT
+# ------------------------------------------
+
+u = torch.randn(n_source_samples, requires_grad=True)
+v = torch.randn(n_target_samples, requires_grad=True)
+
+reg = 0.01
+
+optimizer = torch.optim.Adam([u, v], lr=1)
+
+# number of iterations
+n_iter = 200
+
+
+losses = []
+
+
+for i in range(n_iter):
+
+    # compute the loss on the full dataset (no minibatch sampling here)
+
+    # minus because we maximize the dual loss
+    loss = -ot.stochastic.loss_dual_quadratic(u, v, xs, xt, reg=reg)
+    losses.append(float(loss.detach()))
+
+    if i % 10 == 0:
+        print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+    loss.backward()
+    optimizer.step()
+    optimizer.zero_grad()
+
+
+pl.figure(4)
+pl.plot(losses)
+pl.grid()
+pl.title('Dual objective (negative)')
+pl.xlabel("Iterations")
+
+Gq = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=reg)
+
+
+# %%
+# Plot the estimated quadratic OT plan
+# ------------------------------------
+
+pl.figure(5, (10, 5))
+pl.clf()
+ot.plot.plot2D_samples_mat(Xs, Xt, Gq.detach().numpy(), alpha=0.1)
+pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2)
+pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2)
+pl.legend(loc=0)
+pl.title('OT plan with quadratic regularization')
diff --git a/examples/backends/plot_stoch_continuous_ot_pytorch.py b/examples/backends/plot_stoch_continuous_ot_pytorch.py
new file mode 100644
index 0000000..6d9b916
--- /dev/null
+++ b/examples/backends/plot_stoch_continuous_ot_pytorch.py
@@ -0,0 +1,189 @@
+# -*- coding: utf-8 -*-
+r"""
+======================================================================
+Continuous OT plan estimation with PyTorch
+======================================================================
+
+
+"""
+
+# Author: Remi Flamary
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+import numpy as np
+import matplotlib.pyplot as pl
+import torch
+from torch import nn
+import ot
+import ot.plot
+
+# %%
+# Data generation
+# ---------------
+
+torch.manual_seed(42)
+np.random.seed(42)
+
+n_source_samples = 10000
+n_target_samples = 10000
+theta = 2 * np.pi / 20
+noise_level = 0.1
+
+Xs = np.random.randn(n_source_samples, 2) * 0.5
+Xt = np.random.randn(n_target_samples, 2) * 2
+
+# the target distribution is shifted and rescaled (no linear mapping)
+Xt = Xt + 4
+
+
+# %%
+# Plot data
+# ---------
+nvisu = 300
+pl.figure(1, (5, 5))
+pl.clf()
+pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', label='Source samples', alpha=0.5)
+pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', label='Target samples', alpha=0.5)
+pl.legend(loc=0)
+ax_bounds = pl.axis()
+pl.title('Source and target distributions')
+
+# %%
+# Convert data to torch tensors
+# -----------------------------
+
+xs = torch.tensor(Xs)
+xt = torch.tensor(Xt)
+
+# %%
+# Estimating deep dual variables for entropic OT
+# ----------------------------------------------
+
+torch.manual_seed(42)
+
+# define the MLP model
+
+
+class Potential(torch.nn.Module):
+    def __init__(self):
+        super(Potential, self).__init__()
+        self.fc1 = nn.Linear(2, 200)
+        self.fc2 = nn.Linear(200, 1)
+        self.relu = torch.nn.ReLU()  # instead of Heaviside step fn
+
+    def forward(self, x):
+        output = self.fc1(x)
+        output = self.relu(output)  # instead of Heaviside step fn
+        output = self.fc2(output)
+        return output.ravel()
+
+
+u = Potential().double()
+v = Potential().double()
+
+reg = 1
+
+optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=.005)
+
+# number of iterations
+n_iter = 1000
+n_batch = 500
+
+
+losses = []
+
+for i in range(n_iter):
+
+    # draw a random minibatch of source and target samples
+
+    iperms = torch.randint(0, n_source_samples, (n_batch,))
+    ipermt = torch.randint(0, n_target_samples, (n_batch,))
+
+    xsi = xs[iperms]
+    xti = xt[ipermt]
+
+    # minus because we maximize the dual loss
+    loss = -ot.stochastic.loss_dual_entropic(u(xsi), v(xti), xsi, xti, reg=reg)
+    losses.append(float(loss.detach()))
+
+    if i % 10 == 0:
+        print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+    loss.backward()
+    optimizer.step()
+    optimizer.zero_grad()
+
+
+pl.figure(2)
+pl.plot(losses)
+pl.grid()
+pl.title('Dual objective (negative)')
+pl.xlabel("Iterations")
+
+
+# %%
+# Plot the density on target for a given source sample
+# -----------------------------------------------------
+
+
+nv = 100
+xl = np.linspace(ax_bounds[0], ax_bounds[1], nv)
+yl = np.linspace(ax_bounds[2], ax_bounds[3], nv)
+
+XX, YY = np.meshgrid(xl, yl)
+
+xg = np.concatenate((XX.ravel()[:, None], YY.ravel()[:, None]), axis=1)
+
+wxg = np.exp(-((xg[:, 0] - 4)**2 + (xg[:, 1] - 4)**2) / (2 * 2))
+wxg = wxg / np.sum(wxg)
+
+xg = torch.tensor(xg)
+wxg = torch.tensor(wxg)
+
+
+pl.figure(4, (12, 4))
+pl.clf()
+pl.subplot(1, 3, 1)
+
+iv = 2
+Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
+Gg = Gg.reshape((nv, nv)).detach().numpy()
+
+pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
+pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
+pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
+pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
+pl.legend(loc=0)
+ax_bounds = pl.axis()
+pl.title('Density of transported source sample')
+
+pl.subplot(1, 3, 2)
+
+iv = 3
+Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
+Gg = Gg.reshape((nv, nv)).detach().numpy()
+
+pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
+pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
+pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
+pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
+pl.legend(loc=0)
+ax_bounds = pl.axis()
+pl.title('Density of transported source sample')
+
+pl.subplot(1, 3, 3)
+
+iv = 6
+Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
+Gg = Gg.reshape((nv, nv)).detach().numpy()
+
+pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
+pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
+pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
+pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
+pl.legend(loc=0)
+ax_bounds = pl.axis()
+pl.title('Density of transported source sample')
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 693675f..61be9bb 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -4,12 +4,14 @@ Stochastic solvers for regularized OT.
 
 """
 
-# Author: Kilian Fatras
+# Authors: Kilian Fatras
+#          Rémi Flamary
 #
 # License: MIT License
 
 import numpy as np
-
+from .utils import dist
+from .backend import get_backend
 
 ##############################################################################
 # Optimization toolbox for SEMI - DUAL problems
@@ -747,3 +749,239 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
             return pi, log
     else:
         return pi
+
+
+################################################################################
+# Losses for stochastic optimization
+################################################################################
+
+def loss_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+    r"""
+    Compute the dual loss of the entropic OT as in equations (6)-(7) of [19]
+
+    This loss is backend compatible and can be used for stochastic optimization
+    of the dual potentials. It can be used on the full dataset (beware of
+    memory) or on minibatches.
+
+
+    Parameters
+    ----------
+    u : array-like, shape (ns,)
+        Source dual potential
+    v : array-like, shape (nt,)
+        Target dual potential
+    xs : array-like, shape (ns,d)
+        Source samples
+    xt : array-like, shape (nt,d)
+        Target samples
+    reg : float
+        Regularization term > 0 (default=1)
+    ws : array-like, shape (ns,), optional
+        Source sample weights (default unif)
+    wt : array-like, shape (nt,), optional
+        Target sample weights (default unif)
+    metric : string, callable
+        Ground metric for OT (default quadratic). Can be given as a callable
+        function taking (xs,xt) as parameters.
+
+    Returns
+    -------
+    dual_loss : array-like
+        Dual loss (to maximize)
+
+
+    References
+    ----------
+    .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018)
+    """
+
+    nx = get_backend(u, v, xs, xt)
+
+    if ws is None:
+        ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+    if wt is None:
+        wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+    if callable(metric):
+        M = metric(xs, xt)
+    else:
+        M = dist(xs, xt, metric=metric)
+
+    F = -reg * nx.exp((u[:, None] + v[None, :] - M) / reg)
+
+    return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :])
+
+
+def plan_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+    r"""
+    Compute the primal OT plan of the entropic OT as in equation (8) of [19]
+
+    This function is backend compatible and can be used with dual potentials
+    obtained by stochastic optimization. It can be applied to the full dataset
+    (beware of memory) or to minibatches.
+
+
+    Parameters
+    ----------
+    u : array-like, shape (ns,)
+        Source dual potential
+    v : array-like, shape (nt,)
+        Target dual potential
+    xs : array-like, shape (ns,d)
+        Source samples
+    xt : array-like, shape (nt,d)
+        Target samples
+    reg : float
+        Regularization term > 0 (default=1)
+    ws : array-like, shape (ns,), optional
+        Source sample weights (default unif)
+    wt : array-like, shape (nt,), optional
+        Target sample weights (default unif)
+    metric : string, callable
+        Ground metric for OT (default quadratic). Can be given as a callable
+        function taking (xs,xt) as parameters.
+
+    Returns
+    -------
+    G : array-like
+        Primal OT plan
+
+
+    References
+    ----------
+    .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018)
+    """
+
+    nx = get_backend(u, v, xs, xt)
+
+    if ws is None:
+        ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+    if wt is None:
+        wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+    if callable(metric):
+        M = metric(xs, xt)
+    else:
+        M = dist(xs, xt, metric=metric)
+
+    H = nx.exp((u[:, None] + v[None, :] - M) / reg)
+
+    return ws[:, None] * H * wt[None, :]
+
+
+def loss_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+    r"""
+    Compute the dual loss of the quadratic regularized OT as in equations (6)-(7) of [19]
+
+    This loss is backend compatible and can be used for stochastic optimization
+    of the dual potentials. It can be used on the full dataset (beware of
+    memory) or on minibatches.
+
+
+    Parameters
+    ----------
+    u : array-like, shape (ns,)
+        Source dual potential
+    v : array-like, shape (nt,)
+        Target dual potential
+    xs : array-like, shape (ns,d)
+        Source samples
+    xt : array-like, shape (nt,d)
+        Target samples
+    reg : float
+        Regularization term > 0 (default=1)
+    ws : array-like, shape (ns,), optional
+        Source sample weights (default unif)
+    wt : array-like, shape (nt,), optional
+        Target sample weights (default unif)
+    metric : string, callable
+        Ground metric for OT (default quadratic). Can be given as a callable
+        function taking (xs,xt) as parameters.
+
+    Returns
+    -------
+    dual_loss : array-like
+        Dual loss (to maximize)
+
+
+    References
+    ----------
+    .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018)
+    """
+
+    nx = get_backend(u, v, xs, xt)
+
+    if ws is None:
+        ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+    if wt is None:
+        wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+    if callable(metric):
+        M = metric(xs, xt)
+    else:
+        M = dist(xs, xt, metric=metric)
+
+    F = -1.0 / (4 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0)**2
+
+    return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :])
+
+
+def plan_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+    r"""
+    Compute the primal OT plan of the quadratic regularized OT as in equation (8) of [19]
+
+    This function is backend compatible and can be used with dual potentials
+    obtained by stochastic optimization. It can be applied to the full dataset
+    (beware of memory) or to minibatches.
+
+
+    Parameters
+    ----------
+    u : array-like, shape (ns,)
+        Source dual potential
+    v : array-like, shape (nt,)
+        Target dual potential
+    xs : array-like, shape (ns,d)
+        Source samples
+    xt : array-like, shape (nt,d)
+        Target samples
+    reg : float
+        Regularization term > 0 (default=1)
+    ws : array-like, shape (ns,), optional
+        Source sample weights (default unif)
+    wt : array-like, shape (nt,), optional
+        Target sample weights (default unif)
+    metric : string, callable
+        Ground metric for OT (default quadratic). Can be given as a callable
+        function taking (xs,xt) as parameters.
+
+    Returns
+    -------
+    G : array-like
+        Primal OT plan
+
+
+    References
+    ----------
+    .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018)
+    """
+
+    nx = get_backend(u, v, xs, xt)
+
+    if ws is None:
+        ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+    if wt is None:
+        wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+    if callable(metric):
+        M = metric(xs, xt)
+    else:
+        M = dist(xs, xt, metric=metric)
+
+    H = 1.0 / (2 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0)
+
+    return ws[:, None] * H * wt[None, :]
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 736df32..2b5c0fb 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -8,7 +8,8 @@ for descrete and semicontinous measures from the POT library.
 
 """
 
-# Author: Kilian Fatras
+# Authors: Kilian Fatras
+#          Rémi Flamary
 #
 # License: MIT License
 
@@ -213,3 +214,115 @@ def test_dual_sgd_sinkhorn():
         G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
     np.testing.assert_allclose(
         G_sgd, G_sinkhorn, atol=1e-03)  # cf convergence sgd
+
+
+def test_loss_dual_entropic(nx):
+
+    nx.seed(0)
+
+    xs = nx.randn(50, 2)
+    xt = nx.randn(40, 2) + 2
+
+    ws = nx.rand(50)
+    ws = ws / nx.sum(ws)
+
+    wt = nx.rand(40)
+    wt = wt / nx.sum(wt)
+
+    u = nx.randn(50)
+    v = nx.randn(40)
+
+    def metric(x, y):
+        return -nx.dot(x, y.T)
+
+    ot.stochastic.loss_dual_entropic(u, v, xs, xt)
+
+    ot.stochastic.loss_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+
+def test_plan_dual_entropic(nx):
+
+    nx.seed(0)
+
+    xs = nx.randn(50, 2)
+    xt = nx.randn(40, 2) + 2
+
+    ws = nx.rand(50)
+    ws = ws / nx.sum(ws)
+
+    wt = nx.rand(40)
+    wt = wt / nx.sum(wt)
+
+    u = nx.randn(50)
+    v = nx.randn(40)
+
+    def metric(x, y):
+        return -nx.dot(x, y.T)
+
+    G1 = ot.stochastic.plan_dual_entropic(u, v, xs, xt)
+
+    assert np.all(nx.to_numpy(G1) >= 0)
+    assert G1.shape[0] == 50
+    assert G1.shape[1] == 40
+
+    G2 = ot.stochastic.plan_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+    assert np.all(nx.to_numpy(G2) >= 0)
+    assert G2.shape[0] == 50
+    assert G2.shape[1] == 40
+
+
+def test_loss_dual_quadratic(nx):
+
+    nx.seed(0)
+
+    xs = nx.randn(50, 2)
+    xt = nx.randn(40, 2) + 2
+
+    ws = nx.rand(50)
+    ws = ws / nx.sum(ws)
+
+    wt = nx.rand(40)
+    wt = wt / nx.sum(wt)
+
+    u = nx.randn(50)
+    v = nx.randn(40)
+
+    def metric(x, y):
+        return -nx.dot(x, y.T)
+
+    ot.stochastic.loss_dual_quadratic(u, v, xs, xt)
+
+    ot.stochastic.loss_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+
+def test_plan_dual_quadratic(nx):
+
+    nx.seed(0)
+
+    xs = nx.randn(50, 2)
+    xt = nx.randn(40, 2) + 2
+
+    ws = nx.rand(50)
+    ws = ws / nx.sum(ws)
+
+    wt = nx.rand(40)
+    wt = wt / nx.sum(wt)
+
+    u = nx.randn(50)
+    v = nx.randn(40)
+
+    def metric(x, y):
+        return -nx.dot(x, y.T)
+
+    G1 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt)
+
+    assert np.all(nx.to_numpy(G1) >= 0)
+    assert G1.shape[0] == 50
+    assert G1.shape[1] == 40
+
+    G2 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric)
+
+    assert np.all(nx.to_numpy(G2) >= 0)
+    assert G2.shape[0] == 50
+    assert G2.shape[1] == 40
-- 
cgit v1.2.3
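
A minimal usage sketch of the API added by this patch, placed after the patch trailer so it does not affect the diff itself. It is an illustration under assumptions, not part of the PR: it calls the four new ot.stochastic functions with plain NumPy arrays, which the get_backend dispatch is expected to accept, and uses arbitrary, non-optimized dual potentials, so the printed losses and the plan marginals are illustrative only. With PyTorch tensors the same calls become differentiable, as in the examples above.

# Usage sketch (assumes the patch above is applied; values are illustrative).
import numpy as np
import ot

rng = np.random.RandomState(0)
xs = rng.randn(50, 2)        # source samples, shape (ns, d)
xt = rng.randn(40, 2) + 2    # target samples, shape (nt, d)
u = rng.randn(50)            # source dual potential, shape (ns,)
v = rng.randn(40)            # target dual potential, shape (nt,)

# Dual objectives to be maximized over (u, v); here (u, v) are arbitrary.
loss_ent = ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=0.5)
loss_quad = ot.stochastic.loss_dual_quadratic(u, v, xs, xt, reg=0.01)

# Primal plans recovered from the potentials: nonnegative arrays of shape
# (ns, nt); their marginals match the sample weights only at optimality.
G_ent = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=0.5)
G_quad = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=0.01)

print(float(loss_ent), float(loss_quad))
print(G_ent.shape, G_quad.shape)                # (50, 40) (50, 40)
print(G_ent.min() >= 0.0, G_quad.min() >= 0.0)  # True True

The entropic plan is dense and strictly positive, while the quadratic plan can contain exact zeros; this sparsity is the main practical difference between the two regularizations shown in the examples.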