Diffstat (limited to 'ot/stochastic.py')
-rw-r--r-- | ot/stochastic.py | 242
1 file changed, 240 insertions, 2 deletions
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 693675f..61be9bb 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -4,12 +4,14 @@ Stochastic solvers for regularized OT.
 
 """
 
-# Author: Kilian Fatras <kilian.fatras@gmail.com>
+# Authors: Kilian Fatras <kilian.fatras@gmail.com>
+#          Rémi Flamary <remi.flamary@polytechnique.edu>
 #
 # License: MIT License
 
 import numpy as np
-
+from .utils import dist
+from .backend import get_backend
 
 ##############################################################################
 # Optimization toolbox for SEMI - DUAL problems
@@ -747,3 +749,239 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
         return pi, log
     else:
         return pi
+
+
+################################################################################
+# Losses for stochastic optimization
+################################################################################
+
+def loss_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+    r"""
+    Compute the dual loss of the entropic OT as in equations (6)-(7) of [19]
+
+    This loss is backend compatible and can be used for stochastic optimization
+    of the dual potentials. It can be used on the full dataset (beware of
+    memory) or on minibatches.
+
+
+    Parameters
+    ----------
+    u : array-like, shape (ns,)
+        Source dual potential
+    v : array-like, shape (nt,)
+        Target dual potential
+    xs : array-like, shape (ns,d)
+        Source samples
+    xt : array-like, shape (nt,d)
+        Target samples
+    reg : float
+        Regularization term > 0 (default=1)
+    ws : array-like, shape (ns,), optional
+        Source sample weights (default unif)
+    wt : array-like, shape (nt,), optional
+        Target sample weights (default unif)
+    metric : string, callable
+        Ground metric for OT (default quadratic). Can be given as a callable
+        function taking (xs,xt) as parameters.
+
+    Returns
+    -------
+    dual_loss : array-like
+        Dual loss (to maximize)
+
+
+    References
+    ----------
+    .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018)
+    """
+
+    nx = get_backend(u, v, xs, xt)
+
+    if ws is None:
+        ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+    if wt is None:
+        wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+    if callable(metric):
+        M = metric(xs, xt)
+    else:
+        M = dist(xs, xt, metric=metric)
+
+    F = -reg * nx.exp((u[:, None] + v[None, :] - M) / reg)
+
+    return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :])
+
+
+def plan_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+    r"""
+    Compute the primal OT plan of the entropic OT as in equation (8) of [19]
+
+    This function is backend compatible and can be applied on the full dataset
+    (beware of memory) or on minibatches.
+
+
+    Parameters
+    ----------
+    u : array-like, shape (ns,)
+        Source dual potential
+    v : array-like, shape (nt,)
+        Target dual potential
+    xs : array-like, shape (ns,d)
+        Source samples
+    xt : array-like, shape (nt,d)
+        Target samples
+    reg : float
+        Regularization term > 0 (default=1)
+    ws : array-like, shape (ns,), optional
+        Source sample weights (default unif)
+    wt : array-like, shape (nt,), optional
+        Target sample weights (default unif)
+    metric : string, callable
+        Ground metric for OT (default quadratic). Can be given as a callable
+        function taking (xs,xt) as parameters.
+
+    Returns
+    -------
+    G : array-like
+        Primal OT plan
+
+
+    References
+    ----------
+    .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018)
+    """
+
+    nx = get_backend(u, v, xs, xt)
+
+    if ws is None:
+        ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+    if wt is None:
+        wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+    if callable(metric):
+        M = metric(xs, xt)
+    else:
+        M = dist(xs, xt, metric=metric)
+
+    H = nx.exp((u[:, None] + v[None, :] - M) / reg)
+
+    return ws[:, None] * H * wt[None, :]
+
+
+def loss_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+    r"""
+    Compute the dual loss of the quadratic regularized OT as in equations (6)-(7) of [19]
+
+    This loss is backend compatible and can be used for stochastic optimization
+    of the dual potentials. It can be used on the full dataset (beware of
+    memory) or on minibatches.
+
+
+    Parameters
+    ----------
+    u : array-like, shape (ns,)
+        Source dual potential
+    v : array-like, shape (nt,)
+        Target dual potential
+    xs : array-like, shape (ns,d)
+        Source samples
+    xt : array-like, shape (nt,d)
+        Target samples
+    reg : float
+        Regularization term > 0 (default=1)
+    ws : array-like, shape (ns,), optional
+        Source sample weights (default unif)
+    wt : array-like, shape (nt,), optional
+        Target sample weights (default unif)
+    metric : string, callable
+        Ground metric for OT (default quadratic). Can be given as a callable
+        function taking (xs,xt) as parameters.
+
+    Returns
+    -------
+    dual_loss : array-like
+        Dual loss (to maximize)
+
+
+    References
+    ----------
+    .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018)
+    """
+
+    nx = get_backend(u, v, xs, xt)
+
+    if ws is None:
+        ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+    if wt is None:
+        wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+    if callable(metric):
+        M = metric(xs, xt)
+    else:
+        M = dist(xs, xt, metric=metric)
+
+    F = -1.0 / (4 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0)**2
+
+    return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :])
+
+
+def plan_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+    r"""
+    Compute the primal OT plan of the quadratic regularized OT as in equation (8) of [19]
+
+    This function is backend compatible and can be applied on the full dataset
+    (beware of memory) or on minibatches.
+
+
+    Parameters
+    ----------
+    u : array-like, shape (ns,)
+        Source dual potential
+    v : array-like, shape (nt,)
+        Target dual potential
+    xs : array-like, shape (ns,d)
+        Source samples
+    xt : array-like, shape (nt,d)
+        Target samples
+    reg : float
+        Regularization term > 0 (default=1)
+    ws : array-like, shape (ns,), optional
+        Source sample weights (default unif)
+    wt : array-like, shape (nt,), optional
+        Target sample weights (default unif)
+    metric : string, callable
+        Ground metric for OT (default quadratic). Can be given as a callable
+        function taking (xs,xt) as parameters.
+
+    Returns
+    -------
+    G : array-like
+        Primal OT plan
+
+
+    References
+    ----------
+    .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018)
+    """
+
+    nx = get_backend(u, v, xs, xt)
+
+    if ws is None:
+        ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+    if wt is None:
+        wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+    if callable(metric):
+        M = metric(xs, xt)
+    else:
+        M = dist(xs, xt, metric=metric)
+
+    H = 1.0 / (2 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0)
+
+    return ws[:, None] * H * wt[None, :]
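For reference, the objectives these functions implement can be written out explicitly. The LaTeX below is our transcription of equations (6)-(8) of [19] in the notation of the patch: M_ij is the ground cost between sample pairs, w^s and w^t the sample weights (ws, wt), and epsilon the regularization strength reg. The expressions agree term by term with the returned values above.

    % Entropic regularization: dual objective maximized by loss_dual_entropic
    \max_{u, v}\; \sum_i u_i w^s_i + \sum_j v_j w^t_j
      - \varepsilon \sum_{i,j} w^s_i w^t_j
        \exp\!\left(\frac{u_i + v_j - M_{ij}}{\varepsilon}\right)

    % Corresponding primal plan returned by plan_dual_entropic
    \pi_{ij} = w^s_i w^t_j \exp\!\left(\frac{u_i + v_j - M_{ij}}{\varepsilon}\right)

    % Quadratic (L2) regularization: dual objective of loss_dual_quadratic
    \max_{u, v}\; \sum_i u_i w^s_i + \sum_j v_j w^t_j
      - \frac{1}{4\varepsilon} \sum_{i,j} w^s_i w^t_j
        \left(u_i + v_j - M_{ij}\right)_+^2

    % Corresponding primal plan returned by plan_dual_quadratic
    \pi_{ij} = \frac{w^s_i w^t_j}{2\varepsilon} \left(u_i + v_j - M_{ij}\right)_+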
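As a usage illustration (not part of the patch): because the losses are backend compatible, the dual potentials can be trained by gradient ascent through an autodiff backend such as PyTorch. A minimal sketch, assuming POT is installed with torch support; the toy tensors xs, xt and all hyperparameters are placeholders, not recommended settings.

    import torch
    from ot.stochastic import loss_dual_entropic, plan_dual_entropic

    # toy data: 100 source and 150 target samples in 2D
    xs = torch.randn(100, 2)
    xt = torch.randn(150, 2) + 2

    # dual potentials, one value per sample, learned by gradient ascent
    u = torch.zeros(100, requires_grad=True)
    v = torch.zeros(150, requires_grad=True)
    optimizer = torch.optim.Adam([u, v], lr=0.1)

    for _ in range(500):
        optimizer.zero_grad()
        # the dual loss is to be maximized, so minimize its negation
        loss = -loss_dual_entropic(u, v, xs, xt, reg=1)
        loss.backward()
        optimizer.step()

    # recover the primal plan (equation (8)) from the learned potentials
    with torch.no_grad():
        G = plan_dual_entropic(u, v, xs, xt, reg=1)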
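The same losses accept minibatches, which is the intended large-scale regime: draw index batches each step and evaluate the loss on the matching slices of the potentials and samples, so only a batch-sized cost matrix is ever formed. Another sketch under the same assumptions, here with the quadratic-regularized loss; the batch scheme shown (uniform weights on each minibatch) is one plausible approximation, and sizes are placeholders.

    import torch
    from ot.stochastic import loss_dual_quadratic, plan_dual_quadratic

    ns, nt, bs = 1000, 1500, 64          # sample and batch sizes (placeholders)
    xs = torch.randn(ns, 2)
    xt = torch.randn(nt, 2) + 2

    u = torch.zeros(ns, requires_grad=True)
    v = torch.zeros(nt, requires_grad=True)
    optimizer = torch.optim.Adam([u, v], lr=0.1)

    for _ in range(2000):
        i = torch.randint(0, ns, (bs,))  # random source minibatch
        j = torch.randint(0, nt, (bs,))  # random target minibatch
        optimizer.zero_grad()
        # stochastic gradient of the dual loss on the sub-batch
        loss = -loss_dual_quadratic(u[i], v[j], xs[i], xt[j], reg=1)
        loss.backward()
        optimizer.step()

    # quadratic regularization typically yields a sparse plan: entries with
    # u_i + v_j <= M_ij are exactly zero
    with torch.no_grad():
        G = plan_dual_quadratic(u, v, xs, xt, reg=1)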