path: root/ot/smooth.py
Diffstat (limited to 'ot/smooth.py')
-rw-r--r--  ot/smooth.py  183
1 file changed, 106 insertions(+), 77 deletions(-)
diff --git a/ot/smooth.py b/ot/smooth.py
index 81f6a3e..6855005 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -47,15 +47,24 @@ from scipy.optimize import minimize
def projection_simplex(V, z=1, axis=None):
- """ Projection of x onto the simplex, scaled by z
+ r""" Projection of :math:`\mathbf{V}` onto the simplex, scaled by `z`
- P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2
+ .. math::
+ P\left(\mathbf{V}, z\right) = \mathop{\arg \min}_{\substack{\mathbf{y} \geq 0 \\ \sum_i \mathbf{y}_i = z}} \quad \|\mathbf{y} - \mathbf{V}\|^2
+
+ Parameters
+ ----------
+ V: ndarray, rank 2
z: float or array
- If array, len(z) must be compatible with V
+ If array, len(z) must be compatible with :math:`\mathbf{V}`
axis: None or int
- - axis=None: project V by P(V.ravel(); z)
- - axis=1: project each V[i] by P(V[i]; z[i])
- - axis=0: project each V[:, j] by P(V[:, j]; z[j])
+ - axis=None: project :math:`\mathbf{V}` by :math:`P(\mathbf{V}.\mathrm{ravel}(), z)`
+ - axis=1: project each :math:`\mathbf{V}_i` by :math:`P(\mathbf{V}_i, z_i)`
+ - axis=0: project each :math:`\mathbf{V}_{:, j}` by :math:`P(\mathbf{V}_{:, j}, z_j)`
+
+ Returns
+ -------
+ projection: ndarray, shape :math:`\mathbf{V}`.shape
"""
if axis == 1:
n_features = V.shape[1]
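For orientation, a minimal sketch of calling `projection_simplex` on a small array (the input values below are placeholders):

    import numpy as np

    # Project each row of V onto the probability simplex (z=1 by default).
    V = np.array([[0.9, 0.4, -0.2],
                  [0.3, 0.3, 0.3]])
    P = projection_simplex(V, z=1, axis=1)
    # Every row of P is non-negative and sums to z.
    assert np.all(P >= 0) and np.allclose(P.sum(axis=1), 1.0)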
@@ -77,12 +86,12 @@ def projection_simplex(V, z=1, axis=None):
class Regularization(object):
- """Base class for Regularization objects
+ r"""Base class for Regularization objects
Notes
-----
- This class is not intended for direct use but as aparent for true
- regularizatiojn implementation.
+ This class is not intended for direct use but as a parent class for
+ true regularization implementations.
"""
def __init__(self, gamma=1.0):
@@ -98,40 +107,48 @@ class Regularization(object):
self.gamma = gamma
def delta_Omega(self, X):
- """
- Compute delta_Omega(X[:, j]) for each X[:, j].
- delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y).
+ r"""
+ Compute :math:`\delta_\Omega(\mathbf{X}_{:, j})` for each :math:`\mathbf{X}_{:, j}`.
+
+ .. math::
+ \delta_\Omega(\mathbf{x}) = \sup_{\mathbf{y} \geq 0} \mathbf{y}^T \mathbf{x} - \Omega(\mathbf{y})
Parameters
----------
- X: array, shape = len(a) x len(b)
+ X: array, shape = (len(a), len(b))
Input array.
Returns
-------
- v: array, len(b)
- Values: v[j] = delta_Omega(X[:, j])
- G: array, len(a) x len(b)
- Gradients: G[:, j] = nabla delta_Omega(X[:, j])
+ v: array, (len(b), )
+ Values: :math:`\mathbf{v}_j = \delta_\Omega(\mathbf{X}_{:, j})`
+ G: array, (len(a), len(b))
+ Gradients: :math:`\mathbf{G}_{:, j} = \nabla \delta_\Omega(\mathbf{X}_{:, j})`
"""
raise NotImplementedError
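To make the conjugate concrete: for the negentropy choice :math:`\Omega(\mathbf{y}) = \gamma \sum_i \mathbf{y}_i \log \mathbf{y}_i`, the supremum is attained at :math:`\mathbf{y}^* = \exp(\mathbf{x}/\gamma - 1)`, giving a closed form. A sketch of such a subclass (the class name and body are assumptions derived from the formula above, not taken from the diff):

    import numpy as np

    class NegEntropySketch(Regularization):
        # Hypothetical subclass: Omega(y) = gamma * sum(y * log(y)).
        def delta_Omega(self, X):
            # Column-wise maximizer: y* = exp(X / gamma - 1).
            G = np.exp(X / self.gamma - 1)
            # Substituting y* back gives delta_Omega(x) = gamma * sum(y*).
            val = self.gamma * np.sum(G, axis=0)
            return val, G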
def max_Omega(self, X, b):
- """
- Compute max_Omega_j(X[:, j]) for each X[:, j].
- max_Omega_j(x) = sup_{y >= 0, sum(y) = 1} y^T x - Omega(b[j] y) / b[j].
+ r"""
+ Compute :math:`\mathrm{max}_{\Omega, j}(\mathbf{X}_{:, j})` for each :math:`\mathbf{X}_{:, j}`.
+
+ .. math::
+ \mathrm{max}_{\Omega, j}(\mathbf{x}) =
+ \sup_{\substack{\mathbf{y} \geq 0 \\ \sum_i \mathbf{y}_i = 1}}
+ \mathbf{y}^T \mathbf{x} - \frac{1}{\mathbf{b}_j} \Omega(\mathbf{b}_j \mathbf{y})
Parameters
----------
- X: array, shape = len(a) x len(b)
+ X: array, shape = (len(a), len(b))
Input array.
+ b: array, shape = (len(b), )
+ Second input histogram (should be non-negative and sum to 1).
Returns
-------
- v: array, len(b)
- Values: v[j] = max_Omega_j(X[:, j])
- G: array, len(a) x len(b)
- Gradients: G[:, j] = nabla max_Omega_j(X[:, j])
+ v: array, (len(b), )
+ Values: :math:`\mathbf{v}_j = \mathrm{max}_{\Omega, j}(\mathbf{X}_{:, j})`
+ G: array, (len(a), len(b))
+ Gradients: :math:`\mathbf{G}_{:, j} = \nabla \mathrm{max}_{\Omega, j}(\mathbf{X}_{:, j})`
"""
raise NotImplementedError
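For squared L2 regularization, `max_Omega` reduces to a simplex projection: with :math:`\Omega(\mathbf{y}) = \frac{\gamma}{2} \|\mathbf{y}\|^2`, maximizing :math:`\mathbf{y}^T \mathbf{x} - \frac{\gamma \mathbf{b}_j}{2} \|\mathbf{y}\|^2` over the simplex is equivalent to projecting :math:`\mathbf{x} / (\gamma \mathbf{b}_j)` onto it. A sketch under that assumption (the subclass body is illustrative, not taken from the diff):

    import numpy as np

    class SquaredL2Sketch(Regularization):
        # Hypothetical subclass: Omega(y) = 0.5 * gamma * ||y||^2.
        def max_Omega(self, X, b):
            # Column-wise maximizer: project x / (gamma * b_j) onto the simplex.
            G = projection_simplex(X / (b * self.gamma), axis=0)
            # Objective value at the maximizer, column by column.
            val = np.sum(X * G, axis=0)
            val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0)
            return val, G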
@@ -192,7 +209,7 @@ class SquaredL2(Regularization):
def dual_obj_grad(alpha, beta, a, b, C, regul):
- """
+ r"""
Compute objective value and gradients of dual objective.
Parameters
@@ -203,19 +220,19 @@ def dual_obj_grad(alpha, beta, a, b, C, regul):
a: array, shape = len(a)
b: array, shape = len(b)
Input histograms (should be non-negative and sum to 1).
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a delta_Omega(X) method.
+ Should implement a `delta_Omega(X)` method.
Returns
-------
obj: float
Objective value (higher is better).
grad_alpha: array, shape = len(a)
- Gradient w.r.t. alpha.
+ Gradient w.r.t. `alpha`.
grad_beta: array, shape = len(b)
- Gradient w.r.t. beta.
+ Gradient w.r.t. `beta`.
"""
obj = np.dot(alpha, a) + np.dot(beta, b)
grad_alpha = a.copy()
@@ -242,13 +259,13 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
Parameters
----------
- a: array, shape = len(a)
- b: array, shape = len(b)
+ a: array, shape = (len(a), )
+ b: array, shape = (len(b), )
Input histograms (should be non-negative and sum to 1).
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a delta_Omega(X) method.
+ Should implement a `delta_Omega(X)` method.
method: str
Solver to be used (passed to `scipy.optimize.minimize`).
tol: float
@@ -258,8 +275,8 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
Returns
-------
- alpha: array, shape = len(a)
- beta: array, shape = len(b)
+ alpha: array, shape = (len(a), )
+ beta: array, shape = (len(b), )
Dual potentials.
"""
@@ -302,10 +319,10 @@ def semi_dual_obj_grad(alpha, a, b, C, regul):
a: array, shape = len(a)
b: array, shape = len(b)
Input histograms (should be non-negative and sum to 1).
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a max_Omega(X) method.
+ Should implement a `max_Omega(X)` method.
Returns
-------
@@ -337,13 +354,13 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
Parameters
----------
- a: array, shape = len(a)
- b: array, shape = len(b)
+ a: array, shape = (len(a), )
+ b: array, shape = (len(b), )
Input histograms (should be non-negative and sum to 1).
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a max_Omega(X) method.
+ Should implement a `max_Omega(X)` method.
method: str
Solver to be used (passed to `scipy.optimize.minimize`).
tol: float
@@ -353,7 +370,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
Returns
-------
- alpha: array, shape = len(a)
+ alpha: array, shape = (len(a), )
Semi-dual potentials.
"""
@@ -371,7 +388,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
def get_plan_from_dual(alpha, beta, C, regul):
- """
+ r"""
Retrieve optimal transportation plan from optimal dual potentials.
Parameters
@@ -379,14 +396,14 @@ def get_plan_from_dual(alpha, beta, C, regul):
alpha: array, shape = len(a)
beta: array, shape = len(b)
Optimal dual potentials.
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a delta_Omega(X) method.
+ Should implement a `delta_Omega(X)` method.
Returns
-------
- T: array, shape = len(a) x len(b)
+ T: array, shape = (len(a), len(b))
Optimal transportation plan.
"""
X = alpha[:, np.newaxis] + beta - C
@@ -394,7 +411,7 @@ def get_plan_from_dual(alpha, beta, C, regul):
def get_plan_from_semi_dual(alpha, b, C, regul):
- """
+ r"""
Retrieve optimal transportation plan from optimal semi-dual potentials.
Parameters
@@ -403,14 +420,14 @@ def get_plan_from_semi_dual(alpha, b, C, regul):
Optimal semi-dual potentials.
b: array, shape = len(b)
Second input histogram (should be non-negative and sum to 1).
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a delta_Omega(X) method.
+ Should implement a `delta_Omega(X)` method.
Returns
-------
- T: array, shape = len(a) x len(b)
+ T: array, shape = (len(a), len(b))
Optimal transportation plan.
"""
X = alpha[:, np.newaxis] - C
@@ -422,19 +439,21 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
r"""
Solve the regularized OT problem in the dual and return the OT matrix
- The function solves the smooth relaxed dual formulation (7) in [17]_ :
+ The function solves the smooth relaxed dual formulation (7) in
+ :ref:`[17] <references-smooth-ot-dual>`:
.. math::
- \max_{\alpha,\beta}\quad a^T\alpha+b^T\beta-\sum_j\delta_\Omega(\alpha+\beta_j-\mathbf{m}_j)
+ \max_{\alpha,\beta}\quad \mathbf{a}^T\alpha + \mathbf{b}^T\beta -
+ \sum_j \delta_\Omega \left(\alpha+\beta_j-\mathbf{m}_j \right)
where :
- - :math:`\mathbf{m}_j` is the jth column of the cost matrix
+ - :math:`\mathbf{m}_j` is the j-th column of the cost matrix
- :math:`\delta_\Omega` is the convex conjugate of the regularization term :math:`\Omega`
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The OT matrix can be reconstructed from the gradient of :math:`\delta_\Omega`
- (See [17]_ Proposition 1).
+ (See :ref:`[17] <references-smooth-ot-dual>` Proposition 1).
The optimization algorithm uses gradient descent (L-BFGS by default).
@@ -444,21 +463,25 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
sample weights in the source domain
b : np.ndarray (nt,) or np.ndarray (nt,nbb)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
M : np.ndarray (ns,nt)
loss matrix
reg : float
Regularization term >0
reg_type : str
- Regularization type, can be the following (default ='l2'):
- - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_)
- - 'l2' : Squared Euclidean regularization
+ Regularization type, can be the following (default ='l2'):
+
+ - 'kl' : Kullback-Leibler (~ Neg-entropy used in sinkhorn
+ :ref:`[2] <references-smooth-ot-dual>`)
+
+ - 'l2' : Squared Euclidean regularization
method : str
Solver to use for scipy.optimize.minimize
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -467,15 +490,15 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : (ns, nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if log==True in parameters
+ .. _references-smooth-ot-dual:
References
----------
-
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
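A minimal usage sketch of this entry point (the 2-point histograms and cost below are made-up toy data):

    import numpy as np
    import ot

    a = np.array([0.5, 0.5])
    b = np.array([0.5, 0.5])
    M = np.array([[0.0, 1.0],
                  [1.0, 0.0]])
    G = ot.smooth.smooth_ot_dual(a, b, M, reg=1.0, reg_type='l2')
    # G is the (2, 2) regularized transport plan; its rows sum to ~a.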
@@ -514,21 +537,23 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
r"""
Solve the regularized OT problem in the semi-dual and return the OT matrix
- The function solves the smooth relaxed dual formulation (10) in [17]_ :
+ The function solves the smooth relaxed dual formulation (10) in
+ :ref:`[17] <references-smooth-ot-semi-dual>`:
.. math::
- \max_{\alpha}\quad a^T\alpha-OT_\Omega^*(\alpha,b)
+ \max_{\alpha}\quad \mathbf{a}^T\alpha - \mathrm{OT}_\Omega^*(\alpha, \mathbf{b})
where :
.. math::
- OT_\Omega^*(\alpha,b)=\sum_j b_j
+ \mathrm{OT}_\Omega^*(\alpha, \mathbf{b}) = \sum_j \mathbf{b}_j \mathrm{max}_{\Omega, j}(\alpha - \mathbf{m}_j)
- - :math:`\mathbf{m}_j` is the jth column of the cost matrix
- - :math:`OT_\Omega^*(\alpha,b)` is defined in Eq. (9) in [17]
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{m}_j` is the j-th column of the cost matrix
+ - :math:`\mathrm{OT}_\Omega^*(\alpha,b)` is defined in Eq. (9) in
+ :ref:`[17] <references-smooth-ot-semi-dual>`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
- The OT matrix can is reconstructed using [17]_ Proposition 2.
+ The OT matrix can be reconstructed using :ref:`[17] <references-smooth-ot-semi-dual>` Proposition 2.
The optimization algorithm uses gradient descent (L-BFGS by default).
@@ -538,21 +563,25 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
sample weights in the source domain
b : np.ndarray (nt,) or np.ndarray (nt,nbb)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
M : np.ndarray (ns,nt)
loss matrix
reg : float
Regularization term >0
reg_type : str
- Regularization type, can be the following (default ='l2'):
- - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_)
- - 'l2' : Squared Euclidean regularization
+ Regularization type, can be the following (default ='l2'):
+
+ - 'kl' : Kullback-Leibler (~ Neg-entropy used in sinkhorn
+ :ref:`[2] <references-smooth-ot-semi-dual>`)
+
+ - 'l2' : Squared Euclidean regularization
method : str
Solver to use for scipy.optimize.minimize
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -561,15 +590,15 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : (ns, nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if log==True in parameters
+ .. _references-smooth-ot-semi-dual:
References
----------
-
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
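And the semi-dual counterpart, on the same toy data (here with the 'kl' regularizer):

    import numpy as np
    import ot

    a = np.array([0.5, 0.5])
    b = np.array([0.5, 0.5])
    M = np.array([[0.0, 1.0],
                  [1.0, 0.0]])
    G = ot.smooth.smooth_ot_semi_dual(a, b, M, reg=1.0, reg_type='kl')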