diff options
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r-- | ot/unbalanced.py | 189 |
1 files changed, 186 insertions, 3 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 90c920c..a71a0dd 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -10,6 +10,9 @@ Regularized Unbalanced OT solvers from __future__ import division import warnings +import numpy as np +from scipy.optimize import minimize, Bounds + from .backend import get_backend from .utils import list_to_array # from .utils import unif, dist @@ -269,7 +272,8 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" - Solve the entropic regularization unbalanced optimal transport problem and return the loss + Solve the entropic regularization unbalanced optimal transport problem and + return the OT plan The function solves the following optimization problem: @@ -734,7 +738,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, if weights is None: weights = nx.ones(n_hists, type_as=A) / n_hists else: - assert(len(weights) == A.shape[1]) + assert (len(weights) == A.shape[1]) if log: log = {'err': []} @@ -882,7 +886,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, if weights is None: weights = nx.ones(n_hists, type_as=A) / n_hists else: - assert(len(weights) == A.shape[1]) + assert (len(weights) == A.shape[1]) if log: log = {'err': []} @@ -1252,3 +1256,182 @@ def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, return log_mm['cost'], log_mm else: return log_mm['cost'] + + +def _get_loss_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl'): + """ + return the loss function (scipy.optimize compatible) for regularized + unbalanced OT + """ + + m, n = M.shape + + def kl(p, q): + return np.sum(p * np.log(p / q + 1e-16)) + + def reg_l2(G): + return np.sum((G - a[:, None] * b[None, :])**2) / 2 + + def grad_l2(G): + return G - a[:, None] * b[None, :] + + def reg_kl(G): + return kl(G, a[:, None] * b[None, :]) + + def grad_kl(G): + return np.log(G / (a[:, None] * b[None, :]) + 1e-16) + 1 + + def reg_entropy(G): + return kl(G, 1) + + def grad_entropy(G): + return np.log(G + 1e-16) + 1 + + if reg_div == 'kl': + reg_fun = reg_kl + grad_reg_fun = grad_kl + elif reg_div == 'entropy': + reg_fun = reg_entropy + grad_reg_fun = grad_entropy + else: + reg_fun = reg_l2 + grad_reg_fun = grad_l2 + + def marg_l2(G): + return 0.5 * np.sum((G.sum(1) - a)**2) + 0.5 * np.sum((G.sum(0) - b)**2) + + def grad_marg_l2(G): + return np.outer((G.sum(1) - a), np.ones(n)) + np.outer(np.ones(m), (G.sum(0) - b)) + + def marg_kl(G): + return kl(G.sum(1), a) + kl(G.sum(0), b) + + def grad_marg_kl(G): + return np.outer(np.log(G.sum(1) / a + 1e-16) + 1, np.ones(n)) + np.outer(np.ones(m), np.log(G.sum(0) / b + 1e-16) + 1) + + if regm_div == 'kl': + regm_fun = marg_kl + grad_regm_fun = grad_marg_kl + else: + regm_fun = marg_l2 + grad_regm_fun = grad_marg_l2 + + def _func(G): + G = G.reshape((m, n)) + + # compute loss + val = np.sum(G * M) + reg * reg_fun(G) + reg_m * regm_fun(G) + + # compute gradient + grad = M + reg * grad_reg_fun(G) + reg_m * grad_regm_fun(G) + + return val, grad.ravel() + + return _func + + +def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, + stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False): + r""" + Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B. + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + + \mathrm{reg} \mathrm{div}(\gamma,\mathbf{a}\mathbf{b}^T) + \mathrm{reg_m} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + + s.t. + \gamma \geq 0 + + where: + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + unbalanced distributions + - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence + + The algorithm used for solving the problem is a L-BFGS-B from scipy.optimize + + Parameters + ---------- + a : array-like (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : array-like (dim_b,) + Unnormalized histogram of dimension `dim_b` + M : array-like (dim_a, dim_b) + loss matrix + reg: float + regularization term (>=0) + reg_m: float + Marginal relaxation term >= 0 + reg_div: string, optional + Divergence used for regularization. + Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + reg_div: string, optional + Divergence to quantify the difference between the marginals. + Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + G0: array-like (dim_a, dim_b) + Initialization of the transport matrix + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + ot_distance : array-like + the OT distance between :math:`\mathbf{a}` and :math:`\mathbf{b}` + log : dict + log dictionary returned only if `log` is `True` + + Examples + -------- + >>> import ot + >>> import numpy as np + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[1., 36.],[9., 4.]] + >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'l2'),2) + 0.25 + >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'kl'),2) + 0.57 + + References + ---------- + .. [41] Chapel, L., Flamary, R., Wu, H., FĂ©votte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. + See Also + -------- + ot.lp.emd2 : Unregularized OT loss + ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss + """ + nx = get_backend(M, a, b) + + M0 = M + # convert to humpy + a, b, M = nx.to_numpy(a, b, M) + + if G0 is not None: + G0 = nx.to_numpy(G0) + else: + G0 = np.zeros(M.shape) + + _func = _get_loss_unbalanced(a, b, M, reg, reg_m, reg_div, regm_div) + + res = minimize(_func, G0.ravel(), method=method, jac=True, bounds=Bounds(0, np.inf), + tol=stopThr, options=dict(maxiter=numItermax, disp=verbose)) + + G = nx.from_numpy(res.x.reshape(M.shape), type_as=M0) + + if log: + log = {'loss': nx.from_numpy(res.fun, type_as=M0), 'res': res} + return G, log + else: + return G |