Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r--  ot/unbalanced.py  189
1 file changed, 186 insertions, 3 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 90c920c..a71a0dd 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -10,6 +10,9 @@ Regularized Unbalanced OT solvers
from __future__ import division
import warnings
+import numpy as np
+from scipy.optimize import minimize, Bounds
+
from .backend import get_backend
from .utils import list_to_array
# from .utils import unif, dist
@@ -269,7 +272,8 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
- Solve the entropic regularization unbalanced optimal transport problem and return the loss
+ Solve the entropic regularization unbalanced optimal transport problem and
+ return the OT plan
The function solves the following optimization problem:
@@ -734,7 +738,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
if weights is None:
weights = nx.ones(n_hists, type_as=A) / n_hists
else:
- assert(len(weights) == A.shape[1])
+ assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
@@ -882,7 +886,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
if weights is None:
weights = nx.ones(n_hists, type_as=A) / n_hists
else:
- assert(len(weights) == A.shape[1])
+ assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
@@ -1252,3 +1256,182 @@ def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000,
return log_mm['cost'], log_mm
else:
return log_mm['cost']
+
+
+def _get_loss_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl'):
+ """
+    Return the loss function (compatible with scipy.optimize) for the
+    regularized unbalanced OT problem
+ """
+
+ m, n = M.shape
+
+ def kl(p, q):
+ return np.sum(p * np.log(p / q + 1e-16))
+
+ def reg_l2(G):
+ return np.sum((G - a[:, None] * b[None, :])**2) / 2
+
+ def grad_l2(G):
+ return G - a[:, None] * b[None, :]
+
+ def reg_kl(G):
+ return kl(G, a[:, None] * b[None, :])
+
+ def grad_kl(G):
+ return np.log(G / (a[:, None] * b[None, :]) + 1e-16) + 1
+
+ def reg_entropy(G):
+ return kl(G, 1)
+
+ def grad_entropy(G):
+ return np.log(G + 1e-16) + 1
+
+ if reg_div == 'kl':
+ reg_fun = reg_kl
+ grad_reg_fun = grad_kl
+ elif reg_div == 'entropy':
+ reg_fun = reg_entropy
+ grad_reg_fun = grad_entropy
+ else:
+ reg_fun = reg_l2
+ grad_reg_fun = grad_l2
+
+ def marg_l2(G):
+ return 0.5 * np.sum((G.sum(1) - a)**2) + 0.5 * np.sum((G.sum(0) - b)**2)
+
+ def grad_marg_l2(G):
+ return np.outer((G.sum(1) - a), np.ones(n)) + np.outer(np.ones(m), (G.sum(0) - b))
+
+ def marg_kl(G):
+ return kl(G.sum(1), a) + kl(G.sum(0), b)
+
+ def grad_marg_kl(G):
+ return np.outer(np.log(G.sum(1) / a + 1e-16) + 1, np.ones(n)) + np.outer(np.ones(m), np.log(G.sum(0) / b + 1e-16) + 1)
+
+ if regm_div == 'kl':
+ regm_fun = marg_kl
+ grad_regm_fun = grad_marg_kl
+ else:
+ regm_fun = marg_l2
+ grad_regm_fun = grad_marg_l2
+
+ def _func(G):
+ G = G.reshape((m, n))
+
+ # compute loss
+ val = np.sum(G * M) + reg * reg_fun(G) + reg_m * regm_fun(G)
+
+ # compute gradient
+ grad = M + reg * grad_reg_fun(G) + reg_m * grad_regm_fun(G)
+
+ return val, grad.ravel()
+
+ return _func
+
+
+def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, numItermax=1000,
+ stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False):
+ r"""
+ Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B.
+ The function solves the following optimization problem:
+
+ .. math::
+        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+        \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{a} \mathbf{b}^T) +
+        \mathrm{reg_m} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) +
+        \mathrm{reg_m} \cdot \mathrm{div_m}(\gamma^T \mathbf{1}, \mathbf{b})
+
+ s.t.
+ \gamma \geq 0
+
+ where:
+
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ unbalanced distributions
+ - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence
+
+    The algorithm used for solving the problem is the L-BFGS-B solver from scipy.optimize.
+
+ Parameters
+ ----------
+ a : array-like (dim_a,)
+ Unnormalized histogram of dimension `dim_a`
+ b : array-like (dim_b,)
+ Unnormalized histogram of dimension `dim_b`
+ M : array-like (dim_a, dim_b)
+ loss matrix
+ reg: float
+ regularization term (>=0)
+ reg_m: float
+ Marginal relaxation term >= 0
+ reg_div: string, optional
+ Divergence used for regularization.
+ Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+    regm_div: string, optional
+ Divergence to quantify the difference between the marginals.
+ Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+ G0: array-like (dim_a, dim_b)
+ Initialization of the transport matrix
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+    gamma : (dim_a, dim_b) array-like
+        Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[1., 36.],[9., 4.]]
+ >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'l2'),2)
+ 0.25
+ >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'kl'),2)
+ 0.57
+
+ References
+ ----------
+ .. [41] Chapel, L., Flamary, R., Wu, H., FĂ©votte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression. NeurIPS.
+ See Also
+ --------
+ ot.lp.emd2 : Unregularized OT loss
+ ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
+ """
+ nx = get_backend(M, a, b)
+
+ M0 = M
+    # convert to numpy
+ a, b, M = nx.to_numpy(a, b, M)
+
+ if G0 is not None:
+ G0 = nx.to_numpy(G0)
+ else:
+ G0 = np.zeros(M.shape)
+
+ _func = _get_loss_unbalanced(a, b, M, reg, reg_m, reg_div, regm_div)
+
+ res = minimize(_func, G0.ravel(), method=method, jac=True, bounds=Bounds(0, np.inf),
+ tol=stopThr, options=dict(maxiter=numItermax, disp=verbose))
+
+ G = nx.from_numpy(res.x.reshape(M.shape), type_as=M0)
+
+ if log:
+ log = {'loss': nx.from_numpy(res.fun, type_as=M0), 'res': res}
+ return G, log
+ else:
+ return G
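
Usage sketch (not part of the patch): a minimal example of how the new lbfgsb_unbalanced solver and the internal _get_loss_unbalanced helper added above might be exercised once this diff is applied. The toy data, the reg/reg_m values, and the finite-difference tolerance are illustrative assumptions, not values taken from the diff.

import numpy as np
from ot.unbalanced import lbfgsb_unbalanced, _get_loss_unbalanced

# toy problem (illustrative values)
a = np.array([.5, .5])
b = np.array([.5, .5])
M = np.array([[1., 36.], [9., 4.]])

# OT plan with KL regularization against a b^T and KL marginal relaxation
G = lbfgsb_unbalanced(a, b, M, reg=1.0, reg_m=5.0, reg_div='kl', regm_div='kl')

# sanity-check the analytic gradient returned by the loss helper against a
# forward finite-difference estimate (the helper works on the flattened plan)
func = _get_loss_unbalanced(a, b, M, reg=1.0, reg_m=5.0)
x0 = np.full(M.size, 0.25)
val, grad = func(x0)
eps = 1e-7
fd = np.array([(func(x0 + eps * np.eye(M.size)[i])[0] - val) / eps
               for i in range(M.size)])
assert np.allclose(grad, fd, atol=1e-4)

Because _func returns the objective value and its gradient together, the solver passes jac=True so that scipy.optimize.minimize reuses the pair instead of re-evaluating the loss separately.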