summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py100
1 files changed, 57 insertions, 43 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 2c18a88..5da897d 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -62,7 +62,7 @@ def center_ot_dual(alpha0, beta0, a=None, b=None):
is the following:
.. math::
- \alpha^T a= \beta^T b
+ \alpha^T \mathbf{a} = \beta^T \mathbf{b}
in addition to the OT problem constraints.
@@ -70,11 +70,11 @@ def center_ot_dual(alpha0, beta0, a=None, b=None):
a constant from both :math:`\alpha_0` and :math:`\beta_0`.
.. math::
- c=\frac{\beta0^T b-\alpha_0^T a}{1^Tb+1^Ta}
+ c &= \frac{\beta_0^T \mathbf{b} - \alpha_0^T \mathbf{a}}{\mathbf{1}^T \mathbf{b} + \mathbf{1}^T \mathbf{a}}
- \alpha=\alpha_0+c
+ \alpha &= \alpha_0 + c
- \beta=\beta0+c
+ \beta &= \beta_0 + c
Parameters
----------
@@ -117,7 +117,7 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
The feasible values are computed efficiently but rather coarsely.
.. warning::
- This function is necessary because the C++ solver in emd_c
+ This function is necessary because the C++ solver in `emd_c`
discards all samples in the distributions with
zeros weights. This means that while the primal variable (transport
matrix) is exact, the solver only returns feasible dual potentials
@@ -126,26 +126,26 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
First we compute the constraints violations:
.. math::
- V=\alpha+\beta^T-M
+ \mathbf{V} = \alpha + \beta^T - \mathbf{M}
- Next we compute the max amount of violation per row (alpha) and
- columns (beta)
+ Next we compute the max amount of violation per row (:math:`\alpha`) and
+ columns (:math:`beta`)
.. math::
- v^a_i=\max_j V_{i,j}
+ \mathbf{v^a}_i = \max_j \mathbf{V}_{i,j}
- v^b_j=\max_i V_{i,j}
+ \mathbf{v^b}_j = \max_i \mathbf{V}_{i,j}
Finally we update the dual potential with 0 weights if a
constraint is violated
.. math::
- \alpha_i = \alpha_i -v^a_i \quad \text{ if } a_i=0 \text{ and } v^a_i>0
+ \alpha_i = \alpha_i - \mathbf{v^a}_i \quad \text{ if } \mathbf{a}_i=0 \text{ and } \mathbf{v^a}_i>0
- \beta_j = \beta_j -v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0
+ \beta_j = \beta_j - \mathbf{v^b}_j \quad \text{ if } \mathbf{b}_j=0 \text{ and } \mathbf{v^b}_j > 0
In the end the dual potentials are centered using function
- :ref:`center_ot_dual`.
+ :py:func:`ot.lp.center_ot_dual`.
Note that all those updates do not change the objective value of the
solution but provide dual potentials that do not violate the constraints.
@@ -201,26 +201,28 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
r"""Solves the Earth Movers distance problem and returns the OT matrix
- .. math:: \gamma = arg\min_\gamma <\gamma,M>_F
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
+
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- s.t. \gamma 1 = a
+ \gamma^T \mathbf{1} = \mathbf{b}
- \gamma^T 1= b
+ \gamma \geq 0
- \gamma\geq 0
where :
- - M is the metric cost matrix
- - a and b are the sample weights
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
- .. warning:: Note that the M matrix in numpy needs to be a C-order
+ .. warning:: Note that the :math:`\mathbf{M}` matrix in numpy needs to be a C-order
numpy.array in float64 format. It will be converted if not in this
format
.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
- Uses the algorithm proposed in [1]_
+ Uses the algorithm proposed in :ref:`[1] <references-emd>`.
Parameters
----------
@@ -267,17 +269,19 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
array([[0.5, 0. ],
[0. , 0.5]])
+
+ .. _references-emd:
References
----------
-
.. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
December). Displacement interpolation using Lagrangian mass transport.
In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
See Also
--------
- ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General
- regularized OT"""
+ ot.bregman.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+ """
# convert to numpy if list
a, b, M = list_to_array(a, b, M)
@@ -340,22 +344,23 @@ def emd2(a, b, M, processes=1,
r"""Solves the Earth Movers distance problem and returns the loss
.. math::
- \min_\gamma <\gamma,M>_F
+ \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} = \mathbf{b}
+
+ \gamma \geq 0
- \gamma\geq 0
where :
- - M is the metric cost matrix
- - a and b are the sample weights
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
- Uses the algorithm proposed in [1]_
+ Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
Parameters
----------
@@ -405,9 +410,10 @@ def emd2(a, b, M, processes=1,
>>> ot.emd2(a,b,M)
0.0
+
+ .. _references-emd2:
References
----------
-
.. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
(2011, December). Displacement interpolation using Lagrangian mass
transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
@@ -416,7 +422,8 @@ def emd2(a, b, M, processes=1,
See Also
--------
ot.bregman.sinkhorn : Entropic regularized OT
- ot.optim.cg : General regularized OT"""
+ ot.optim.cg : General regularized OT
+ """
a, b, M = list_to_array(a, b, M)
@@ -508,29 +515,35 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally:
.. math::
- \min_X \sum_{i=1}^N w_i W_2^2(b, X, a_i, X_i)
+ \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_2^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)
where :
- :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one
- - the :math:`a_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i`
- - the :math:`X_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations
- - :math:`b \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
+ - the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i`
+ - the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations
+ - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
- This problem is considered in [1] (Algorithm 2). There are two differences with the following codes:
+ This problem is considered in :ref:`[1] <references-free-support-barycenter>` (Algorithm 2).
+ There are two differences with the following codes:
- we do not optimize over the weights
- - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.
+ - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in
+ :ref:`[1] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
+ implementation of the fixed-point algorithm of
+ :ref:`[2] <references-free-support-barycenter>` proposed in the continuous setting.
Parameters
----------
measures_locations : list of N (k_i,d) numpy.ndarray
- The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
+ The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
+ (:math:`k_i` can be different for each element of the list)
measures_weights : list of N (k_i,) numpy.ndarray
- Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure
+ Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
+ representing the weights of each discrete input measure
X_init : (k,d) np.ndarray
- Initialization of the support locations (on k atoms) of the barycenter
+ Initialization of the support locations (on `k` atoms) of the barycenter
b : (k,) np.ndarray
Initialization of the weights of the barycenter (non-negatives, sum to 1)
weights : (N,) np.ndarray
@@ -554,9 +567,10 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
X : (k,d) np.ndarray
Support locations (on k atoms) of the barycenter
+
+ .. _references-free-support-barycenter:
References
----------
-
.. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
.. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.