From 10accb13c2f22c946b65b249d7aae6e4f6af7579 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Mon, 22 Jul 2019 14:53:45 +0200 Subject: add unbalanced with stabilization --- ot/unbalanced.py | 279 ++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 245 insertions(+), 34 deletions(-) (limited to 'ot') diff --git a/ot/unbalanced.py b/ot/unbalanced.py index f6c2d5f..ca24e8b 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -9,10 +9,12 @@ Regularized Unbalanced OT from __future__ import division import warnings import numpy as np +from scipy.misc import logsumexp + # from .utils import unif, dist -def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, +def sinkhorn_unbalanced(a, b, M, reg, mu, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" Solve the unbalanced entropic regularization optimal transport problem and return the loss @@ -20,7 +22,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\mu KL(\gamma 1, a) + \\mu KL(\gamma^T 1, b) s.t. \gamma\geq 0 @@ -45,11 +47,11 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, loss matrix reg : float Entropy regularization term > 0 - alpha : float + mu : float Marginal relaxation term > 0 method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + 'sinkhorn_reg_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -95,22 +97,29 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, -------- ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10] ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10] - ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_reg_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10] """ if method.lower() == 'sinkhorn': def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + return sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: + elif method.lower() == 'sinkhorn_stabilized': + def sink(): + return sinkhorn_stabilized_unbalanced(a, b, M, reg, mu, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + return sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) @@ -120,7 +129,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, return sink() -def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', +def sinkhorn_unbalanced2(a, b, M, reg, mu, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" @@ -129,7 +138,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\mu KL(\gamma 1, a) + \\mu KL(\gamma^T 1, b) s.t. \gamma\geq 0 @@ -154,11 +163,11 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', loss matrix reg : float Entropy regularization term > 0 - alpha : float + mu : float Marginal relaxation term > 0 method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + 'sinkhorn_reg_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -203,22 +212,29 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', -------- ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] - ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] """ if method.lower() == 'sinkhorn': def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + return sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: + elif method.lower() == 'sinkhorn_stabilized': + def sink(): + return sinkhorn_stabilized_unbalanced(a, b, M, reg, mu, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + return sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) @@ -232,7 +248,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', return sink() -def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, +def sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -240,7 +256,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\mu KL(\gamma 1, a) + \\mu KL(\gamma^T 1, b) s.t. \gamma\geq 0 @@ -265,7 +281,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, loss matrix reg : float Entropy regularization term > 0 - alpha : float + mu : float Marginal relaxation term > 0 numItermax : int, optional Max number of iterations @@ -338,14 +354,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, u = np.ones(n_a) / n_a v = np.ones(n_b) / n_b - # print(reg) # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute K = np.empty(M.shape, dtype=M.dtype) np.divide(M, -reg, out=K) np.exp(K, out=K) - # print(np.min(K)) - fi = alpha / (alpha + reg) + fi = mu / (mu + reg) cpt = 0 err = 1. @@ -371,8 +385,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err_u = abs(u - uprev).max() / max(abs(u), abs(uprev), 1.) - err_v = abs(v - vprev).max() / max(abs(v), abs(vprev), 1.) + err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) + err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) err = 0.5 * (err_u + err_v) if log: log['err'].append(err) @@ -383,8 +397,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, print('{:5d}|{:8e}|'.format(cpt, err)) cpt = cpt + 1 if log: - log['u'] = u - log['v'] = v + log['logu'] = np.log(u + 1e-16) + log['logv'] = np.log(v + 1e-16) if n_hists: # return only loss res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) @@ -401,7 +415,204 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, return u[:, None] * K * v[None, :] -def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, +def sinkhorn_stabilized_unbalanced(a, b, M, reg, mu, tau=1e5, numItermax=1000, + stopThr=1e-9, verbose=False, log=False, + **kwargs): + r""" + Solve the entropic regularization unbalanced optimal transport problem and return the loss + + The function solves the following optimization problem using log-domain + stabilization as proposed in [10]: + + .. math:: + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\mu KL(\gamma 1, a) + \\mu KL(\gamma^T 1, b) + + s.t. + \gamma\geq 0 + where : + + - M is the (ns, nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights + - KL is the Kullback-Leibler divergence + + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) or np.ndarray (nt, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) + M : np.ndarray (ns,nt) + loss matrix + reg : float + Entropy regularization term > 0 + mu : float + Marginal relaxation term > 0 + tau : float + thershold for max value in u or v for log scaling + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.],[1., 0.]] + >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.) + array([[0.51122823, 0.18807035], + [0.18807035, 0.51122823]]) + + References + ---------- + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + + n_a, n_b = M.shape + + if len(a) == 0: + a = np.ones(n_a, dtype=np.float64) / n_a + if len(b) == 0: + b = np.ones(n_b, dtype=np.float64) / n_b + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if n_hists: + u = np.ones((n_a, n_hists)) / n_a + v = np.ones((n_b, n_hists)) / n_b + a = a.reshape(n_a, 1) + else: + u = np.ones(n_a) / n_a + v = np.ones(n_b) / n_b + + # print(reg) + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + fi = mu / (mu + reg) + + cpt = 0 + err = 1. + alpha = np.zeros(n_a) + beta = np.zeros(n_b) + while (err > stopThr and cpt < numItermax): + uprev = u + vprev = v + + Kv = K.dot(v) + f_alpha = np.exp(- alpha / (reg + mu)) + f_beta = np.exp(- beta / (reg + mu)) + + if n_hists: + f_alpha = f_alpha[:, None] + f_beta = f_beta[:, None] + u = ((a / (Kv + 1e-16)) ** fi) * f_alpha + Ktu = K.T.dot(u) + v = ((b / (Ktu + 1e-16)) ** fi) * f_beta + if (u > tau).any() or (v > tau).any(): + if n_hists: + alpha = alpha + reg * np.log(np.max(u, 1)) + beta = beta + reg * np.log(np.max(v, 1)) + else: + alpha = alpha + reg * np.log(np.max(u)) + beta = beta + reg * np.log(np.max(v)) + K = np.exp((alpha[:, None] + beta[None, :] - + M) / reg) + v = np.ones_like(v) + Kv = K.dot(v) + + if (np.any(Ktu == 0.) + or np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn('Numerical errors at iteration %d' % cpt) + u = uprev + v = vprev + break + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), + 1.) + if log: + log['err'].append(err) + if verbose: + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + cpt = cpt + 1 + + if n_hists: + logu = alpha[:, None] / reg + np.log(u) + logv = beta[:, None] / reg + np.log(v) + else: + logu = alpha / reg + np.log(u) + logv = beta / reg + np.log(v) + if log: + log['logu'] = logu + log['logv'] = logv + if n_hists: # return only loss + res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] + + logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1)) + res = np.exp(res) + if log: + return res, log + else: + return res + + else: # return OT matrix + ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg) + if log: + return ot_matrix, log + else: + return ot_matrix + + +def barycenter_unbalanced(A, M, reg, mu, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): r"""Compute the entropic regularized unbalanced wasserstein barycenter of distributions A @@ -415,7 +626,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - - alpha is the marginal relaxation hyperparameter + - mu is the marginal relaxation hyperparameter The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ Parameters @@ -426,7 +637,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, loss matrix for OT reg : float Entropy regularization term > 0 - alpha : float + mu : float Marginal relaxation term > 0 weights : np.ndarray (n,) Weights of each histogram a_i on the simplex (barycentric coodinates) @@ -467,7 +678,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, K = np.exp(- M / reg) - fi = alpha / (alpha + reg) + fi = mu / (mu + reg) v = np.ones((p, n_hists)) / p u = np.ones((p, 1)) / p @@ -499,8 +710,8 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err_u = abs(u - uprev).max() / max(abs(u), abs(uprev), 1.) - err_v = abs(v - vprev).max() / max(abs(v), abs(vprev), 1.) + err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) + err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) err = 0.5 * (err_u + err_v) if log: log['err'].append(err) @@ -513,8 +724,8 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, cpt += 1 if log: log['niter'] = cpt - log['u'] = u - log['v'] = v + log['logu'] = np.log(u + 1e-16) + log['logv'] = np.log(v + 1e-16) return q, log else: return q -- cgit v1.2.3