Diffstat (limited to 'ot/stochastic.py')
-rw-r--r--  ot/stochastic.py  242
1 file changed, 240 insertions, 2 deletions
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 693675f..61be9bb 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -4,12 +4,14 @@ Stochastic solvers for regularized OT.
"""
-# Author: Kilian Fatras <kilian.fatras@gmail.com>
+# Authors: Kilian Fatras <kilian.fatras@gmail.com>
+# RĂ©mi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
import numpy as np
-
+from .utils import dist
+from .backend import get_backend
##############################################################################
# Optimization toolbox for SEMI-DUAL problems
@@ -747,3 +749,239 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
return pi, log
else:
return pi
+
+
+################################################################################
+# Losses for stochastic optimization
+################################################################################
+
+def loss_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+ Compute the dual loss of the entropic OT as in equations (6)-(7) of [19]
+
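+ Writing :math:`\mathbf{w}^s, \mathbf{w}^t` for the sample weights and
+ :math:`\mathbf{M}` for the ground cost matrix, the objective computed
+ below is
+
+ .. math::
+     \sum_i u_i w^s_i + \sum_j v_j w^t_j - \mathrm{reg} \sum_{i,j} w^s_i w^t_j \exp \left( \frac{u_i + v_j - M_{i,j}}{\mathrm{reg}} \right)
+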
+ This loss is backend compatible and can be used for stochastic optimization
+ of the dual potentials. It can be used on the full dataset (beware of
+ memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+ xs : array-like, shape (ns,d)
+ Source samples
+ xt : array-like, shape (nt,d)
+ Target samples
+ reg : float
+ Regularization term > 0 (default=1)
+ ws : array-like, shape (ns,), optional
+ Source sample weights (default unif)
+ wt : array-like, shape (nt,), optional
+ Target sample weights (default unif)
+ metric : string, callable
+ Ground metric for OT (default squared Euclidean). Can be given as a
+ callable function taking (xs, xt) as parameters.
+
+ Returns
+ -------
+ dual_loss : array-like
+ Dual loss (to maximize)
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018)
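+
+ Examples
+ --------
+ A minimal sketch on random data, with zero-initialized potentials and
+ the default uniform weights (NumPy backend):
+
+ >>> import numpy as np
+ >>> from ot.stochastic import loss_dual_entropic
+ >>> rng = np.random.RandomState(0)
+ >>> xs, xt = rng.randn(10, 2), rng.randn(12, 2)
+ >>> u, v = np.zeros(10), np.zeros(12)
+ >>> loss = loss_dual_entropic(u, v, xs, xt, reg=1)
+
+ In practice u and v would be the parameters updated by a stochastic
+ optimizer (e.g. on a torch backend) and xs, xt would be minibatches.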
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ F = -reg * nx.exp((u[:, None] + v[None, :] - M) / reg)
+
+ return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :])
+
+
+def plan_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+ Compute the primal OT plan of the entropic OT as in equation (8) of [19]
+
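+ Writing :math:`\mathbf{w}^s, \mathbf{w}^t` for the sample weights and
+ :math:`\mathbf{M}` for the ground cost matrix, the plan computed below is
+
+ .. math::
+     \gamma_{i,j} = w^s_i w^t_j \exp \left( \frac{u_i + v_j - M_{i,j}}{\mathrm{reg}} \right)
+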
+ This function is backend compatible and can be used for stochastic optimization
+ of the dual potentials. It can be used on the full dataset (beware of
+ memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+ xs : array-like, shape (ns,d)
+ Source samples
+ xt : array-like, shape (nt,d)
+ Target samples
+ reg : float
+ Regularization term > 0 (default=1)
+ ws : array-like, shape (ns,), optional
+ Source sample weights (default unif)
+ wt : array-like, shape (nt,), optional
+ Target sample weights (default unif)
+ metric : string, callable
+ Ground metric for OT (default squared Euclidean). Can be given as a
+ callable function taking (xs, xt) as parameters.
+
+ Returns
+ -------
+ G : array-like, shape (ns, nt)
+ Primal OT plan
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018)
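+
+ Examples
+ --------
+ A minimal sketch recovering a plan from dual potentials (zeros here;
+ in practice the result of a stochastic optimization), NumPy backend:
+
+ >>> import numpy as np
+ >>> from ot.stochastic import plan_dual_entropic
+ >>> rng = np.random.RandomState(0)
+ >>> xs, xt = rng.randn(10, 2), rng.randn(12, 2)
+ >>> u, v = np.zeros(10), np.zeros(12)
+ >>> G = plan_dual_entropic(u, v, xs, xt, reg=1)
+ >>> G.shape
+ (10, 12)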
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ H = nx.exp((u[:, None] + v[None, :] - M) / reg)
+
+ return ws[:, None] * H * wt[None, :]
+
+
+def loss_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+ Compute the dual loss of the quadratic regularized OT as in equations (6)-(7) of [19]
+
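+ Writing :math:`(x)_+ = \max(x, 0)`, the objective computed below is
+
+ .. math::
+     \sum_i u_i w^s_i + \sum_j v_j w^t_j - \frac{1}{4 \mathrm{reg}} \sum_{i,j} w^s_i w^t_j \left( u_i + v_j - M_{i,j} \right)_+^2
+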
+ This loss is backend compatible and can be used for stochastic optimization
+ of the dual potentials. It can be used on the full dataset (beware of
+ memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+ xs : array-like, shape (ns,d)
+ Source samples
+ xt : array-like, shape (nt,d)
+ Target samples
+ reg : float
+ Regularization term > 0 (default=1)
+ ws : array-like, shape (ns,), optional
+ Source sample weights (default unif)
+ wt : array-like, shape (nt,), optional
+ Target sample weights (default unif)
+ metric : string, callable
+ Ground metric for OT (default squared Euclidean). Can be given as a
+ callable function taking (xs, xt) as parameters.
+
+ Returns
+ -------
+ dual_loss : array-like
+ Dual loss (to maximize)
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018)
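+
+ Examples
+ --------
+ A minimal sketch on random data (NumPy backend). With zero potentials
+ and a non-negative cost the objective is exactly zero:
+
+ >>> import numpy as np
+ >>> from ot.stochastic import loss_dual_quadratic
+ >>> rng = np.random.RandomState(0)
+ >>> xs, xt = rng.randn(10, 2), rng.randn(12, 2)
+ >>> u, v = np.zeros(10), np.zeros(12)
+ >>> float(loss_dual_quadratic(u, v, xs, xt, reg=1))
+ 0.0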
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ F = -1.0 / (4 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0)**2
+
+ return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :])
+
+
+def plan_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+ Compute the primal OT plan of the quadratic regularized OT as in equation (8) of [19]
+
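+ Writing :math:`(x)_+ = \max(x, 0)`, the plan computed below is
+
+ .. math::
+     \gamma_{i,j} = \frac{w^s_i w^t_j}{2 \mathrm{reg}} \left( u_i + v_j - M_{i,j} \right)_+
+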
+ This function is backend compatible and can be used for stochastic optimization
+ of the dual potentials. It can be used on the full dataset (beware of
+ memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+ xs : array-like, shape (ns,d)
+ Source samples
+ xt : array-like, shape (nt,d)
+ Target samples
+ reg : float
+ Regularization term > 0 (default=1)
+ ws : array-like, shape (ns,), optional
+ Source sample weights (default unif)
+ wt : array-like, shape (nt,), optional
+ Target sample weights (default unif)
+ metric : string, callable
+ Ground metric for OT (default squared Euclidean). Can be given as a
+ callable function taking (xs, xt) as parameters.
+
+ Returns
+ -------
+ G : array-like, shape (ns, nt)
+ Primal OT plan
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018)
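+
+ Examples
+ --------
+ A minimal sketch checking the shape of the returned plan (zeros as
+ placeholder potentials; the plan is only meaningful once u, v are
+ optimized), NumPy backend:
+
+ >>> import numpy as np
+ >>> from ot.stochastic import plan_dual_quadratic
+ >>> rng = np.random.RandomState(0)
+ >>> xs, xt = rng.randn(10, 2), rng.randn(12, 2)
+ >>> u, v = np.zeros(10), np.zeros(12)
+ >>> G = plan_dual_quadratic(u, v, xs, xt, reg=1)
+ >>> G.shape
+ (10, 12)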
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ H = 1.0 / (2 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0)
+
+ return ws[:, None] * H * wt[None, :]