diff options
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r-- | ot/unbalanced.py | 808 |
1 files changed, 656 insertions, 152 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 50ec03c..d516dfc 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -9,51 +9,56 @@ Regularized Unbalanced OT from __future__ import division import warnings import numpy as np +from scipy.special import logsumexp + # from .utils import unif, dist -def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, + stopThr=1e-6, verbose=False, log=False, **kwargs): r""" - Solve the unbalanced entropic regularization optimal transport problem and return the loss + Solve the unbalanced entropic regularization optimal transport problem + and return the OT plan The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) s.t. \gamma\geq 0 where : - - M is the (ns, nt) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization + term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target unbalanced distributions - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt,n_hists) - samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns, nt) + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension dim_a + b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + One or multiple unnormalized histograms of dimension dim_b + If many, compute all the OT distances (a, b_i) + M : np.ndarray (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 - alpha : float + reg_m: float Marginal relaxation term > 0 method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + 'sinkhorn_reg_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshol on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -62,10 +67,16 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, Returns ------- - W : (nt) ndarray or float - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters + if n_hists == 1: + gamma : (dim_a x dim_b) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + else: + ot_distance : (n_hists,) ndarray + the OT distance between `a` and each of the histograms `b_i` + log : dict + log dictionary returned only if `log` is `True` Examples -------- @@ -82,83 +93,96 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems + (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. - .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : + Learning with a Wasserstein Loss, Advances in Neural Information + Processing Systems (NIPS) 2015 See Also -------- ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10] - ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10] - ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_stabilized_unbalanced: + Unbalanced Stabilized sinkhorn [9][10] + ot.unbalanced.sinkhorn_reg_scaling_unbalanced: + Unbalanced Sinkhorn with epslilon scaling [9][10] """ if method.lower() == 'sinkhorn': - def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) - - elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() == 'sinkhorn_stabilized': + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - - def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) else: - raise ValueError('Unknown method. Using classic Sinkhorn Knopp') + raise ValueError("Unknown method '%s'." % method) - return sink() - -def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', - numItermax=1000, stopThr=1e-9, verbose=False, +def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', + numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" - Solve the entropic regularization unbalanced optimal transport problem and return the loss + Solve the entropic regularization unbalanced optimal transport problem and + return the loss The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) s.t. \gamma\geq 0 where : - - M is the (ns, nt) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target unbalanced distributions - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt, n_hists) - samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension dim_a + b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + One or multiple unnormalized histograms of dimension dim_b + If many, compute all the OT distances (a, b_i) + M : np.ndarray (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 - alpha : float + reg_m: float Marginal relaxation term > 0 method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + 'sinkhorn_reg_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -171,10 +195,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', Returns ------- - W : (nt) ndarray or float - Optimal transportation matrix for the given parameters + ot_distance : (n_hists,) ndarray + the OT distance between `a` and each of the histograms `b_i` log : dict - log dictionary return only if log==True in parameters + log dictionary returned only if `log` is `True` Examples -------- @@ -191,64 +215,70 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems + (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. - .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : + Learning with a Wasserstein Loss, Advances in Neural Information + Processing Systems (NIPS) 2015 See Also -------- ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] - ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] """ - - if method.lower() == 'sinkhorn': - def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) - - elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: - warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - - def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) - else: - raise ValueError('Unknown method. Using classic Sinkhorn Knopp') - b = np.asarray(b, dtype=np.float64) if len(b.shape) < 2: b = b[:, None] - - return sink() + if method.lower() == 'sinkhorn': + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() == 'sinkhorn_stabilized': + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: + warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError('Unknown method %s.' % method) -def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \reg_m KL(\gamma 1, a) + \reg_m KL(\gamma^T 1, b) s.t. \gamma\geq 0 where : - - M is the (ns, nt) metric cost matrix + - M is the (dim_a, dim_b) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights + - a and b are source and target unbalanced distributions - KL is the Kullback-Leibler divergence The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ @@ -256,21 +286,21 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, Parameters ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt, n_hists) - samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension dim_a + b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + One or multiple unnormalized histograms of dimension dim_b + If many, compute all the OT distances (a, b_i) + M : np.ndarray (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 - alpha : float + reg_m: float Marginal relaxation term > 0 numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshol on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional @@ -279,11 +309,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, Returns ------- - gamma : (ns x nt) ndarray - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - + if n_hists == 1: + gamma : (dim_a x dim_b) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + else: + ot_distance : (n_hists,) ndarray + the OT distance between `a` and each of the histograms `b_i` + log : dict + log dictionary returned only if `log` is `True` Examples -------- @@ -298,9 +333,13 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, References ---------- - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. - .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : + Learning with a Wasserstein Loss, Advances in Neural Information + Processing Systems (NIPS) 2015 See Also -------- @@ -313,12 +352,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) - n_a, n_b = M.shape + dim_a, dim_b = M.shape if len(a) == 0: - a = np.ones(n_a, dtype=np.float64) / n_a + a = np.ones(dim_a, dtype=np.float64) / dim_a if len(b) == 0: - b = np.ones(n_b, dtype=np.float64) / n_b + b = np.ones(dim_b, dtype=np.float64) / dim_b if len(b.shape) > 1: n_hists = b.shape[1] @@ -331,21 +370,19 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((n_a, 1)) / n_a - v = np.ones((n_b, n_hists)) / n_b - a = a.reshape(n_a, 1) + u = np.ones((dim_a, 1)) / dim_a + v = np.ones((dim_b, n_hists)) / dim_b + a = a.reshape(dim_a, 1) else: - u = np.ones(n_a) / n_a - v = np.ones(n_b) / n_b + u = np.ones(dim_a) / dim_a + v = np.ones(dim_b) / dim_b - # print(reg) # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute K = np.empty(M.shape, dtype=M.dtype) np.divide(M, -reg, out=K) np.exp(K, out=K) - # print(np.min(K)) - fi = alpha / (alpha + reg) + fi = reg_m / (reg_m + reg) cpt = 0 err = 1. @@ -364,15 +401,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration', cpt) + warnings.warn('Numerical errors at iteration %s' % cpt) u = uprev v = vprev break if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ - np.sum((v - vprev)**2) / np.sum((v)**2) + err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) + err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) + err = 0.5 * (err_u + err_v) if log: log['err'].append(err) if verbose: @@ -380,10 +418,11 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(cpt, err)) - cpt = cpt + 1 + cpt += 1 + if log: - log['u'] = u - log['v'] = v + log['logu'] = np.log(u + 1e-16) + log['logv'] = np.log(v + 1e-16) if n_hists: # return only loss res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) @@ -400,9 +439,224 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, return u[:, None] * K * v[None, :] -def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): - r"""Compute the entropic regularized unbalanced wasserstein barycenter of distributions A +def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, + **kwargs): + r""" + Solve the entropic regularization unbalanced optimal transport + problem and return the loss + + The function solves the following optimization problem using log-domain + stabilization as proposed in [10]: + + .. math:: + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) + + s.t. + \gamma\geq 0 + where : + + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization + term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target unbalanced distributions + - KL is the Kullback-Leibler divergence + + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + + + Parameters + ---------- + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension dim_a + b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + One or multiple unnormalized histograms of dimension dim_b + If many, compute all the OT distances (a, b_i) + M : np.ndarray (dim_a, dim_b) + loss matrix + reg : float + Entropy regularization term > 0 + reg_m: float + Marginal relaxation term > 0 + tau : float + thershold for max value in u or v for log scaling + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + if n_hists == 1: + gamma : (dim_a x dim_b) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + else: + ot_distance : (n_hists,) ndarray + the OT distance between `a` and each of the histograms `b_i` + log : dict + log dictionary returned only if `log` is `True` + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.],[1., 0.]] + >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.) + array([[0.51122823, 0.18807035], + [0.18807035, 0.51122823]]) + + References + ---------- + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : + Learning with a Wasserstein Loss, Advances in Neural Information + Processing Systems (NIPS) 2015 + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + + dim_a, dim_b = M.shape + + if len(a) == 0: + a = np.ones(dim_a, dtype=np.float64) / dim_a + if len(b) == 0: + b = np.ones(dim_b, dtype=np.float64) / dim_b + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if n_hists: + u = np.ones((dim_a, n_hists)) / dim_a + v = np.ones((dim_b, n_hists)) / dim_b + a = a.reshape(dim_a, 1) + else: + u = np.ones(dim_a) / dim_a + v = np.ones(dim_b) / dim_b + + # print(reg) + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + fi = reg_m / (reg_m + reg) + + cpt = 0 + err = 1. + alpha = np.zeros(dim_a) + beta = np.zeros(dim_b) + while (err > stopThr and cpt < numItermax): + uprev = u + vprev = v + + Kv = K.dot(v) + f_alpha = np.exp(- alpha / (reg + reg_m)) + f_beta = np.exp(- beta / (reg + reg_m)) + + if n_hists: + f_alpha = f_alpha[:, None] + f_beta = f_beta[:, None] + u = ((a / (Kv + 1e-16)) ** fi) * f_alpha + Ktu = K.T.dot(u) + v = ((b / (Ktu + 1e-16)) ** fi) * f_beta + absorbing = False + if (u > tau).any() or (v > tau).any(): + absorbing = True + if n_hists: + alpha = alpha + reg * np.log(np.max(u, 1)) + beta = beta + reg * np.log(np.max(v, 1)) + else: + alpha = alpha + reg * np.log(np.max(u)) + beta = beta + reg * np.log(np.max(v)) + K = np.exp((alpha[:, None] + beta[None, :] - + M) / reg) + v = np.ones_like(v) + Kv = K.dot(v) + + if (np.any(Ktu == 0.) + or np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn('Numerical errors at iteration %s' % cpt) + u = uprev + v = vprev + break + if (cpt % 10 == 0 and not absorbing) or cpt == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), + 1.) + if log: + log['err'].append(err) + if verbose: + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + cpt = cpt + 1 + + if err > stopThr: + warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + + "Try a larger entropy `reg` or a lower mass `reg_m`." + + "Or a larger absorption threshold `tau`.") + if n_hists: + logu = alpha[:, None] / reg + np.log(u) + logv = beta[:, None] / reg + np.log(v) + else: + logu = alpha / reg + np.log(u) + logv = beta / reg + np.log(v) + if log: + log['logu'] = logu + log['logv'] = logv + if n_hists: # return only loss + res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] + + logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1)) + res = np.exp(res) + if log: + return res, log + else: + return res + + else: # return OT matrix + ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg) + if log: + return ot_matrix, log + else: + return ot_matrix + + +def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False): + r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization. The function solves the following optimization problem: @@ -411,28 +665,35 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, where : - - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - - alpha is the marginal relaxation hyperparameter - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized + Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) + - :math:`\mathbf{a}_i` are training distributions in the columns of + matrix :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT + - reg_mis the marginal relaxation hyperparameter + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ Parameters ---------- - A : np.ndarray (d,n) - n training distributions a_i of size d - M : np.ndarray (d,d) - loss matrix for OT + A : np.ndarray (dim, n_hists) + `n_hists` training distributions a_i of dimension dim + M : np.ndarray (dim, dim) + ground metric matrix for OT. reg : float Entropy regularization term > 0 - alpha : float + reg_m : float Marginal relaxation term > 0 - weights : np.ndarray (n,) - Weights of each histogram a_i on the simplex (barycentric coodinates) + tau : float + Stabilization threshold for log domain absorption. + weights : np.ndarray (n_hists,) optional + Weight of each distribution (barycentric coodinates) + If None, uniform weights are used. numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshol on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional @@ -441,7 +702,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, Returns ------- - a : (d,) ndarray + a : (dim,) ndarray Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -450,12 +711,165 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, + G. (2015). Iterative Bregman projections for regularized transportation + problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. """ - p, n_hists = A.shape + dim, n_hists = A.shape + if weights is None: + weights = np.ones(n_hists) / n_hists + else: + assert(len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + fi = reg_m / (reg_m + reg) + + u = np.ones((dim, n_hists)) / dim + v = np.ones((dim, n_hists)) / dim + + # print(reg) + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + fi = reg_m / (reg_m + reg) + + cpt = 0 + err = 1. + alpha = np.zeros(dim) + beta = np.zeros(dim) + q = np.ones(dim) / dim + while (err > stopThr and cpt < numItermax): + qprev = q + Kv = K.dot(v) + f_alpha = np.exp(- alpha / (reg + reg_m)) + f_beta = np.exp(- beta / (reg + reg_m)) + f_alpha = f_alpha[:, None] + f_beta = f_beta[:, None] + u = ((A / (Kv + 1e-16)) ** fi) * f_alpha + Ktu = K.T.dot(u) + q = (Ktu ** (1 - fi)) * f_beta + q = q.dot(weights) ** (1 / (1 - fi)) + Q = q[:, None] + v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta + absorbing = False + if (u > tau).any() or (v > tau).any(): + absorbing = True + alpha = alpha + reg * np.log(np.max(u, 1)) + beta = beta + reg * np.log(np.max(v, 1)) + K = np.exp((alpha[:, None] + beta[None, :] - + M) / reg) + v = np.ones_like(v) + Kv = K.dot(v) + if (np.any(Ktu == 0.) + or np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn('Numerical errors at iteration %s' % cpt) + q = qprev + break + if (cpt % 10 == 0 and not absorbing) or cpt == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = abs(q - qprev).max() / max(abs(q).max(), + abs(qprev).max(), 1.) + if log: + log['err'].append(err) + if verbose: + if cpt % 50 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + if err > stopThr: + warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + + "Try a larger entropy `reg` or a lower mass `reg_m`." + + "Or a larger absorption threshold `tau`.") + if log: + log['niter'] = cpt + log['logu'] = np.log(u + 1e-16) + log['logv'] = np.log(v + 1e-16) + return q, log + else: + return q + + +def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False): + r"""Compute the entropic unbalanced wasserstein barycenter of A. + + The function solves the following optimization problem with a + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized + Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT + - reg_mis the marginal relaxation hyperparameter + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + + Parameters + ---------- + A : np.ndarray (dim, n_hists) + `n_hists` training distributions a_i of dimension dim + M : np.ndarray (dim, dim) + ground metric matrix for OT. + reg : float + Entropy regularization term > 0 + reg_m: float + Marginal relaxation term > 0 + weights : np.ndarray (n_hists,) optional + Weight of each distribution (barycentric coodinates) + If None, uniform weights are used. + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + a : (dim,) ndarray + Unbalanced Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. + (2015). Iterative Bregman projections for regularized transportation + problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprin + arXiv:1607.05816. + + + """ + dim, n_hists = A.shape if weights is None: weights = np.ones(n_hists) / n_hists else: @@ -466,10 +880,10 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, K = np.exp(- M / reg) - fi = alpha / (alpha + reg) + fi = reg_m / (reg_m + reg) - v = np.ones((p, n_hists)) / p - u = np.ones((p, 1)) / p + v = np.ones((dim, n_hists)) / dim + u = np.ones((dim, 1)) / dim cpt = 0 err = 1. @@ -498,8 +912,11 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \ - np.sum((v - vprev) ** 2) / np.sum((v) ** 2) + err_u = abs(u - uprev).max() + err_u /= max(abs(u).max(), abs(uprev).max(), 1.) + err_v = abs(v - vprev).max() + err_v /= max(abs(v).max(), abs(vprev).max(), 1.) + err = 0.5 * (err_u + err_v) if log: log['err'].append(err) if verbose: @@ -511,8 +928,95 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, cpt += 1 if log: log['niter'] = cpt - log['u'] = u - log['v'] = v + log['logu'] = np.log(u + 1e-16) + log['logv'] = np.log(v + 1e-16) return q, log else: return q + + +def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False, **kwargs): + r"""Compute the entropic unbalanced wasserstein barycenter of A. + + The function solves the following optimization problem with a + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized + Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT + - reg_mis the marginal relaxation hyperparameter + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + + Parameters + ---------- + A : np.ndarray (dim, n_hists) + `n_hists` training distributions a_i of dimension dim + M : np.ndarray (dim, dim) + ground metric matrix for OT. + reg : float + Entropy regularization term > 0 + reg_m: float + Marginal relaxation term > 0 + weights : np.ndarray (n_hists,) optional + Weight of each distribution (barycentric coodinates) + If None, uniform weights are used. + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + a : (dim,) ndarray + Unbalanced Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. + (2015). Iterative Bregman projections for regularized transportation + problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprin + arXiv:1607.05816. + + """ + + if method.lower() == 'sinkhorn': + return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() == 'sinkhorn_stabilized': + return barycenter_unbalanced_stabilized(A, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: + warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') + return barycenter_unbalanced(A, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) |