summaryrefslogtreecommitdiff
path: root/ot/unbalanced.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r--ot/unbalanced.py223
1 files changed, 223 insertions, 0 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 503cc1e..90c920c 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -4,6 +4,7 @@ Regularized Unbalanced OT solvers
"""
# Author: Hicham Janati <hicham.janati@inria.fr>
+# Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
# License: MIT License
from __future__ import division
@@ -1029,3 +1030,225 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
log=log, **kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
+
+
+def mm_unbalanced(a, b, M, reg_m, div='kl', G0=None, numItermax=1000,
+                  stopThr=1e-15, verbose=False, log=False):
+    r"""
+    Solve the unbalanced optimal transport problem and return the OT plan.
+
+    The function solves the following optimization problem:
+
+    .. math::
+        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+        \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+        \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+        s.t.
+             \gamma \geq 0
+
+    where:
+
+    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+      unbalanced distributions
+    - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or
+      :math:`\ell_2` divergence
+
+    The algorithm used for solving the problem is a majorization-
+    minimization (MM) algorithm as proposed in
+    :ref:`[41] <references-mm-unbalanced>`
+
+    Parameters
+    ----------
+    a : array-like (dim_a,)
+        Unnormalized histogram of dimension `dim_a`. An empty array is
+        replaced by uniform weights `1 / dim_a`.
+    b : array-like (dim_b,)
+        Unnormalized histogram of dimension `dim_b`. An empty array is
+        replaced by uniform weights `1 / dim_b`.
+    M : array-like (dim_a, dim_b)
+        loss matrix
+    reg_m: float
+        Marginal relaxation term > 0
+    div: string, optional
+        Divergence to quantify the difference between the marginals.
+        Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic).
+        Any other value raises a warning and is replaced by 'kl'.
+    G0: array-like (dim_a, dim_b), optional
+        Initialization of the transport matrix. If None, the outer product
+        of `a` and `b` is used.
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (> 0). The error is the Frobenius norm of
+        the difference between two successive transport matrices.
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+    Returns
+    -------
+    gamma : (dim_a, dim_b) array-like
+        Optimal transportation matrix for the given parameters
+    log : dict
+        log dictionary returned only if `log` is `True`. It holds the
+        per-iteration errors (`'err'`), the per-iteration transport
+        matrices (`'G'`) and the final linear cost
+        :math:`\langle \gamma, \mathbf{M} \rangle_F` (`'cost'`).
+
+    Examples
+    --------
+    >>> import ot
+    >>> import numpy as np
+    >>> a=[.5, .5]
+    >>> b=[.5, .5]
+    >>> M=[[1., 36.],[9., 4.]]
+    >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'kl'), 2)
+    array([[0.3 , 0.  ],
+           [0.  , 0.07]])
+    >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'l2'), 2)
+    array([[0.25, 0.  ],
+           [0.  , 0.  ]])
+
+
+    .. _references-mm-unbalanced:
+    References
+    ----------
+    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+        Unbalanced optimal transport through non-negative penalized
+        linear regression. NeurIPS.
+
+    See Also
+    --------
+    ot.lp.emd : Unregularized OT
+    ot.unbalanced.sinkhorn_unbalanced : Entropic regularized OT
+    """
+    M, a, b = list_to_array(M, a, b)
+    nx = get_backend(M, a, b)
+
+    dim_a, dim_b = M.shape
+
+    # Empty marginals are interpreted as uniform weights.
+    if len(a) == 0:
+        a = nx.ones(dim_a, type_as=M) / dim_a
+    if len(b) == 0:
+        b = nx.ones(dim_b, type_as=M) / dim_b
+
+    # Default starting point: outer product of the two marginals.
+    G = a[:, None] * b[None, :] if G0 is None else G0
+
+    # Keep the boolean flag and the log container in separate names.
+    log_dict = {'err': [], 'G': []} if log else None
+
+    # Validate the divergence once, so the kernel is computed in one place.
+    # The message is built by implicit concatenation: a backslash-newline
+    # inside the literal would embed the next line's indentation in it.
+    if div not in ('kl', 'l2'):
+        warnings.warn("The div parameter should be either equal to 'kl' or "
+                      "'l2': it has been set to 'kl'.")
+        div = 'kl'
+
+    # Multiplicative kernel, constant across iterations.
+    if div == 'kl':
+        K = nx.exp(M / - reg_m / 2)
+    else:
+        K = nx.maximum(a[:, None] + b[None, :] - M / reg_m / 2,
+                       nx.zeros((dim_a, dim_b), type_as=M))
+
+    for i in range(numItermax):
+        Gprev = G
+
+        # The 1e-16 terms guard the divisions against zero marginals.
+        if div == 'kl':
+            u = nx.sqrt(a / (nx.sum(G, 1) + 1e-16))
+            v = nx.sqrt(b / (nx.sum(G, 0) + 1e-16))
+            G = G * K * u[:, None] * v[None, :]
+        else:
+            Gd = (nx.sum(G, 0, keepdims=True)
+                  + nx.sum(G, 1, keepdims=True) + 1e-16)
+            G = G * K / Gd
+
+        # Frobenius norm of the change between successive plans.
+        err = nx.sqrt(nx.sum((G - Gprev) ** 2))
+        if log:
+            log_dict['err'].append(err)
+            log_dict['G'].append(G)
+        if verbose:
+            print('{:5d}|{:8e}|'.format(i, err))
+        if err < stopThr:
+            break
+
+    if log:
+        log_dict['cost'] = nx.sum(G * M)
+        return G, log_dict
+    else:
+        return G
+
+
+def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000,
+                   stopThr=1e-15, verbose=False, log=False):
+    r"""
+    Solve the unbalanced optimal transport problem and return the OT cost.
+
+    The underlying optimization problem is:
+
+    .. math::
+        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+        \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+        \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+        s.t.
+             \gamma \geq 0
+
+    where:
+
+    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+      unbalanced distributions
+    - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or
+      :math:`\ell_2` divergence
+
+    The problem is solved by delegating to :func:`mm_unbalanced`, a
+    majorization-minimization (MM) algorithm as proposed in
+    :ref:`[41] <references-mm-unbalanced2>`. The value returned is the
+    `'cost'` entry of the log produced by that solver.
+
+    Parameters
+    ----------
+    a : array-like (dim_a,)
+        Unnormalized histogram of dimension `dim_a`
+    b : array-like (dim_b,)
+        Unnormalized histogram of dimension `dim_b`
+    M : array-like (dim_a, dim_b)
+        loss matrix
+    reg_m: float
+        Marginal relaxation term > 0
+    div: string, optional
+        Divergence to quantify the difference between the marginals.
+        Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+    G0: array-like (dim_a, dim_b)
+        Initialization of the transport matrix
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (> 0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+    Returns
+    -------
+    ot_distance : array-like
+        the OT distance between :math:`\mathbf{a}` and :math:`\mathbf{b}`
+    log : dict
+        log dictionary returned only if `log` is `True`
+
+    Examples
+    --------
+    >>> import ot
+    >>> import numpy as np
+    >>> a=[.5, .5]
+    >>> b=[.5, .5]
+    >>> M=[[1., 36.],[9., 4.]]
+    >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'l2'),2)
+    0.25
+    >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'kl'),2)
+    0.57
+
+
+    .. _references-mm-unbalanced2:
+    References
+    ----------
+    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+        Unbalanced optimal transport through non-negative penalized
+        linear regression. NeurIPS.
+
+    See Also
+    --------
+    ot.lp.emd2 : Unregularized OT loss
+    ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
+    """
+    # Logging is always requested from the solver because the cost is only
+    # exposed through its log dictionary; the plan itself is discarded.
+    _plan, solver_log = mm_unbalanced(
+        a, b, M, reg_m,
+        div=div,
+        G0=G0,
+        numItermax=numItermax,
+        stopThr=stopThr,
+        verbose=verbose,
+        log=True,
+    )
+    ot_cost = solver_log['cost']
+
+    return (ot_cost, solver_log) if log else ot_cost