From 3c53834d46f093f5770ec76748beb5667bebb6fa Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 12 Jun 2019 15:50:00 +0200 Subject: add unbalanced sinkhorn algorithm --- ot/unbalanced.py | 404 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 404 insertions(+) create mode 100644 ot/unbalanced.py (limited to 'ot/unbalanced.py') diff --git a/ot/unbalanced.py b/ot/unbalanced.py new file mode 100644 index 0000000..8bd02eb --- /dev/null +++ b/ot/unbalanced.py @@ -0,0 +1,404 @@ +# -*- coding: utf-8 -*- +""" +Regularized Unbalanced OT +""" + +# Author: Hicham Janati +# License: MIT License + +import numpy as np +# from .utils import unif, dist + + +def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): + u""" + Solve the unbalanced entropic regularization optimal transport problem and return the loss + + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + alpha KL(\gamma 1, a) + alpha KL(\gamma^T 1, b) + + s.t. + \gamma\geq 0 + where : + + - M is the (ns, nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights + - KL is the Kullback-Leibler divergence + + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) or np.ndarray (nt,n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) + M : np.ndarray (ns, nt) + loss matrix + reg : float + Regularization term > 0 + alpha : float + Regulatization term > 0 + method : str + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + 'sinkhorn_epsilon_scaling', see those function for specific parameters + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + W : (nt) ndarray or float + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.sinkhorn2(a, b, M, 1, 1) + array([0.26894142]) + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] + ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + + """ + + if method.lower() == 'sinkhorn': + def sink(): + return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) + # elif method.lower() == 'sinkhorn_stabilized': + # def sink(): + # return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + # stopThr=stopThr, verbose=verbose, log=log, **kwargs) + # elif method.lower() == 'sinkhorn_epsilon_scaling': + # def sink(): + # return sinkhorn_epsilon_scaling( + # a, b, M, reg, numItermax=numItermax, + # stopThr=stopThr, verbose=verbose, log=log, **kwargs) + else: + print('Warning : unknown method. Falling back to classic Sinkhorn Knopp') + + def sink(): + return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) + + return sink() + + +def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): + u""" + Solve the entropic regularization unbalanced optimal transport problem and return the loss + + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + alpha KL(\gamma 1, a) + alpha KL(\gamma^T 1, b) + + s.t. + \gamma\geq 0 + where : + + - M is the (ns, nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights + - KL is the Kullback-Leibler divergence + + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) or np.ndarray (nt, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term > 0 + alpha: float + Regularization term > 0 + method : str + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + 'sinkhorn_epsilon_scaling', see those function for specific parameters + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + W : (nt) ndarray or float + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .10] + >>> b=[.5, .5] + >>> M=[[0., 1.],[1., 0.]] + >>> ot.sinkhorn2(a, b, M, 1., 1.) + array([ 0.26894142]) + + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] + ot.bregman.greenkhorn : Greenkhorn [21] + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] + ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + + """ + + if method.lower() == 'sinkhorn': + def sink(): + return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) + # elif method.lower() == 'sinkhorn_stabilized': + # def sink(): + # return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + # stopThr=stopThr, verbose=verbose, log=log, **kwargs) + # elif method.lower() == 'sinkhorn_epsilon_scaling': + # def sink(): + # return sinkhorn_epsilon_scaling( + # a, b, M, reg, numItermax=numItermax, + # stopThr=stopThr, verbose=verbose, log=log, **kwargs) + else: + print('Warning : unknown method using classic Sinkhorn Knopp') + + def sink(): + return sinkhorn_knopp(a, b, M, reg, alpha, **kwargs) + + b = np.asarray(b, dtype=np.float64) + if len(b.shape) < 2: + b = b[None, :] + + return sink() + + +def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): + """ + Solve the entropic regularization unbalanced optimal transport problem and return the loss + + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + alpha KL(\gamma 1, a) + alpha KL(\gamma^T 1, b) + + s.t. + \gamma\geq 0 + where : + + - M is the (ns, nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights + - KL is the Kullback-Leibler divergence + + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) or np.ndarray (nt, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term > 0 + alpha: float + Regularization term > 0 + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .15] + >>> b=[.5, .5] + >>> M=[[0., 1.],[1., 0.]] + >>> ot.sinkhorn(a, b, M, 1., 1.) + array([[ 0.36552929, 0.13447071], + [ 0.13447071, 0.36552929]]) + + + References + ---------- + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + + n_a, n_b = M.shape + + if len(a) == 0: + a = np.ones(n_a, dtype=np.float64) / n_a + if len(b) == 0: + b = np.ones(n_b, dtype=np.float64) / n_b + + assert n_a == len(a) and n_b == len(b) + if b.ndim > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if n_hists: + u = np.ones((n_a, n_hists)) / n_a + v = np.ones((n_b, n_hists)) / n_b + else: + u = np.ones(n_a) / n_a + v = np.ones(n_b) / n_b + + # print(reg) + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + # print(np.min(K)) + fi = alpha / (alpha + reg) + + cpt = 0 + err = 1. + while (err > stopThr and cpt < numItermax): + uprev = u + vprev = v + + Kv = K.dot(v) + u = (a / Kv) ** fi + Ktu = K.T.dot(u) + v = (b / Ktu) ** fi + + if (np.any(Ktu == 0.) + or np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + print('Warning: numerical errors at iteration', cpt) + u = uprev + v = vprev + break + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ + np.sum((v - vprev)**2) / np.sum((v)**2) + if log: + log['err'].append(err) + if verbose: + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + cpt = cpt + 1 + if log: + log['u'] = u + log['v'] = v + + if n_hists: # return only loss + res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) + if log: + return res, log + else: + return res + + else: # return OT matrix + + if log: + return u[:, None] * K * v[None, :], log + else: + return u[:, None] * K * v[None, :] -- cgit v1.2.3