From cfdbbd21642c6082164b84db78c2ead07499a113 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Fri, 19 Jul 2019 17:04:14 +0200 Subject: remove square in convergence check add unbalanced with stabilization add unbalanced tests with stabilization fix doctest examples add xvfb in travis remove explicit call xvfb in travis change alpha to reg_m minor flake8 remove redundant sink definitions + better doc and naming add stabilized unbalanced barycenter + add not converged warnings add test for stable barycenter add generic barycenter func + make method funcs private fix typo + add method test for barycenters fix doc examples + add xml to gitignore fix whitespace in example change logsumexp import - scipy deprecation warning fix doctest improve naming + add stable barycenter in bregman add test for stable bar + test the method arg in bregman --- ot/unbalanced.py | 803 ++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 653 insertions(+), 150 deletions(-) (limited to 'ot/unbalanced.py') diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 0f0692e..3f71d28 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -9,51 +9,56 @@ Regularized Unbalanced OT from __future__ import division import warnings import numpy as np +from scipy.special import logsumexp + # from .utils import unif, dist -def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, + stopThr=1e-6, verbose=False, log=False, **kwargs): r""" - Solve the unbalanced entropic regularization optimal transport problem and return the loss + Solve the unbalanced entropic regularization optimal transport problem + and return the OT plan The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) s.t. \gamma\geq 0 where : - - M is the (ns, nt) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization + term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target unbalanced distributions - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt,n_hists) - samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns, nt) + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension dim_a + b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + One or multiple unnormalized histograms of dimension dim_b + If many, compute all the OT distances (a, b_i) + M : np.ndarray (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 - alpha : float + reg_m: float Marginal relaxation term > 0 method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + 'sinkhorn_reg_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshol on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -62,10 +67,16 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, Returns ------- - W : (nt) ndarray or float - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters + if n_hists == 1: + gamma : (dim_a x dim_b) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + else: + ot_distance : (n_hists,) ndarray + the OT distance between `a` and each of the histograms `b_i` + log : dict + log dictionary returned only if `log` is `True` Examples -------- @@ -82,83 +93,96 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems + (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. - .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : + Learning with a Wasserstein Loss, Advances in Neural Information + Processing Systems (NIPS) 2015 See Also -------- ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10] - ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10] - ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_stabilized_unbalanced: + Unbalanced Stabilized sinkhorn [9][10] + ot.unbalanced.sinkhorn_reg_scaling_unbalanced: + Unbalanced Sinkhorn with epslilon scaling [9][10] """ if method.lower() == 'sinkhorn': - def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) - - elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: + return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() == 'sinkhorn_stabilized': + return _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - - def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) else: - raise ValueError('Unknown method. Using classic Sinkhorn Knopp') - - return sink() + raise ValueError("Unknown method '%s'." % method) -def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', - numItermax=1000, stopThr=1e-9, verbose=False, +def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', + numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" - Solve the entropic regularization unbalanced optimal transport problem and return the loss + Solve the entropic regularization unbalanced optimal transport problem and + return the loss The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) s.t. \gamma\geq 0 where : - - M is the (ns, nt) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target unbalanced distributions - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt, n_hists) - samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension dim_a + b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + One or multiple unnormalized histograms of dimension dim_b + If many, compute all the OT distances (a, b_i) + M : np.ndarray (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 - alpha : float + reg_m: float Marginal relaxation term > 0 method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + 'sinkhorn_reg_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -171,10 +195,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', Returns ------- - W : (nt) ndarray or float - Optimal transportation matrix for the given parameters + ot_distance : (n_hists,) ndarray + the OT distance between `a` and each of the histograms `b_i` log : dict - log dictionary return only if log==True in parameters + log dictionary returned only if `log` is `True` Examples -------- @@ -191,64 +215,70 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems + (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. - .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : + Learning with a Wasserstein Loss, Advances in Neural Information + Processing Systems (NIPS) 2015 See Also -------- ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] - ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] """ - - if method.lower() == 'sinkhorn': - def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) - - elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: - warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - - def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) - else: - raise ValueError('Unknown method. Using classic Sinkhorn Knopp') - b = np.asarray(b, dtype=np.float64) if len(b.shape) < 2: b = b[:, None] - - return sink() + if method.lower() == 'sinkhorn': + return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() == 'sinkhorn_stabilized': + return _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: + warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') + return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError('Unknown method %s.' % method) -def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \reg_m KL(\gamma 1, a) + \reg_m KL(\gamma^T 1, b) s.t. \gamma\geq 0 where : - - M is the (ns, nt) metric cost matrix + - M is the (dim_a, dim_b) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights + - a and b are source and target unbalanced distributions - KL is the Kullback-Leibler divergence The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ @@ -256,16 +286,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, Parameters ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt, n_hists) - samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension dim_a + b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + One or multiple unnormalized histograms of dimension dim_b + If many, compute all the OT distances (a, b_i) + M : np.ndarray (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 - alpha : float + reg_m: float Marginal relaxation term > 0 numItermax : int, optional Max number of iterations @@ -279,11 +309,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, Returns ------- - gamma : (ns x nt) ndarray - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - + if n_hists == 1: + gamma : (dim_a x dim_b) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + else: + ot_distance : (n_hists,) ndarray + the OT distance between `a` and each of the histograms `b_i` + log : dict + log dictionary returned only if `log` is `True` Examples -------- @@ -291,16 +326,20 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] - >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) + >>> ot.unbalanced._sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) array([[0.51122823, 0.18807035], [0.18807035, 0.51122823]]) References ---------- - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. - .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : + Learning with a Wasserstein Loss, Advances in Neural Information + Processing Systems (NIPS) 2015 See Also -------- @@ -313,12 +352,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) - n_a, n_b = M.shape + dim_a, dim_b = M.shape if len(a) == 0: - a = np.ones(n_a, dtype=np.float64) / n_a + a = np.ones(dim_a, dtype=np.float64) / dim_a if len(b) == 0: - b = np.ones(n_b, dtype=np.float64) / n_b + b = np.ones(dim_b, dtype=np.float64) / dim_b if len(b.shape) > 1: n_hists = b.shape[1] @@ -331,21 +370,19 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((n_a, 1)) / n_a - v = np.ones((n_b, n_hists)) / n_b - a = a.reshape(n_a, 1) + u = np.ones((dim_a, 1)) / dim_a + v = np.ones((dim_b, n_hists)) / dim_b + a = a.reshape(dim_a, 1) else: - u = np.ones(n_a) / n_a - v = np.ones(n_b) / n_b + u = np.ones(dim_a) / dim_a + v = np.ones(dim_b) / dim_b - # print(reg) # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute K = np.empty(M.shape, dtype=M.dtype) np.divide(M, -reg, out=K) np.exp(K, out=K) - # print(np.min(K)) - fi = alpha / (alpha + reg) + fi = reg_m / (reg_m + reg) cpt = 0 err = 1. @@ -371,8 +408,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ - np.sum((v - vprev)**2) / np.sum((v)**2) + err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) + err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) + err = 0.5 * (err_u + err_v) if log: log['err'].append(err) if verbose: @@ -383,8 +421,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, cpt += 1 if log: - log['u'] = u - log['v'] = v + log['logu'] = np.log(u + 1e-16) + log['logv'] = np.log(v + 1e-16) if n_hists: # return only loss res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) @@ -401,9 +439,224 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, return u[:, None] * K * v[None, :] -def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): - r"""Compute the entropic regularized unbalanced wasserstein barycenter of distributions A +def _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, + **kwargs): + r""" + Solve the entropic regularization unbalanced optimal transport + problem and return the loss + + The function solves the following optimization problem using log-domain + stabilization as proposed in [10]: + + .. math:: + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) + + s.t. + \gamma\geq 0 + where : + + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization + term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target unbalanced distributions + - KL is the Kullback-Leibler divergence + + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + + + Parameters + ---------- + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension dim_a + b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + One or multiple unnormalized histograms of dimension dim_b + If many, compute all the OT distances (a, b_i) + M : np.ndarray (dim_a, dim_b) + loss matrix + reg : float + Entropy regularization term > 0 + reg_m: float + Marginal relaxation term > 0 + tau : float + thershold for max value in u or v for log scaling + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + if n_hists == 1: + gamma : (dim_a x dim_b) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + else: + ot_distance : (n_hists,) ndarray + the OT distance between `a` and each of the histograms `b_i` + log : dict + log dictionary returned only if `log` is `True` + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.],[1., 0.]] + >>> ot.unbalanced._sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.) + array([[0.51122823, 0.18807035], + [0.18807035, 0.51122823]]) + + References + ---------- + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : + Learning with a Wasserstein Loss, Advances in Neural Information + Processing Systems (NIPS) 2015 + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + + dim_a, dim_b = M.shape + + if len(a) == 0: + a = np.ones(dim_a, dtype=np.float64) / dim_a + if len(b) == 0: + b = np.ones(dim_b, dtype=np.float64) / dim_b + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if n_hists: + u = np.ones((dim_a, n_hists)) / dim_a + v = np.ones((dim_b, n_hists)) / dim_b + a = a.reshape(dim_a, 1) + else: + u = np.ones(dim_a) / dim_a + v = np.ones(dim_b) / dim_b + + # print(reg) + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + fi = reg_m / (reg_m + reg) + + cpt = 0 + err = 1. + alpha = np.zeros(dim_a) + beta = np.zeros(dim_b) + while (err > stopThr and cpt < numItermax): + uprev = u + vprev = v + + Kv = K.dot(v) + f_alpha = np.exp(- alpha / (reg + reg_m)) + f_beta = np.exp(- beta / (reg + reg_m)) + + if n_hists: + f_alpha = f_alpha[:, None] + f_beta = f_beta[:, None] + u = ((a / (Kv + 1e-16)) ** fi) * f_alpha + Ktu = K.T.dot(u) + v = ((b / (Ktu + 1e-16)) ** fi) * f_beta + absorbing = False + if (u > tau).any() or (v > tau).any(): + absorbing = True + if n_hists: + alpha = alpha + reg * np.log(np.max(u, 1)) + beta = beta + reg * np.log(np.max(v, 1)) + else: + alpha = alpha + reg * np.log(np.max(u)) + beta = beta + reg * np.log(np.max(v)) + K = np.exp((alpha[:, None] + beta[None, :] - + M) / reg) + v = np.ones_like(v) + Kv = K.dot(v) + + if (np.any(Ktu == 0.) + or np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn('Numerical errors at iteration %s' % cpt) + u = uprev + v = vprev + break + if (cpt % 10 == 0 and not absorbing) or cpt == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), + 1.) + if log: + log['err'].append(err) + if verbose: + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + cpt = cpt + 1 + + if err > stopThr: + warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + + "Try a larger entropy `reg` or a lower mass `reg_m`." + + "Or a larger absorption threshold `tau`.") + if n_hists: + logu = alpha[:, None] / reg + np.log(u) + logv = beta[:, None] / reg + np.log(v) + else: + logu = alpha / reg + np.log(u) + logv = beta / reg + np.log(v) + if log: + log['logu'] = logu + log['logv'] = logv + if n_hists: # return only loss + res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] + + logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1)) + res = np.exp(res) + if log: + return res, log + else: + return res + + else: # return OT matrix + ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg) + if log: + return ot_matrix, log + else: + return ot_matrix + + +def _barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False): + r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization. The function solves the following optimization problem: @@ -412,28 +665,184 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, where : - - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - - alpha is the marginal relaxation hyperparameter - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized + Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) + - :math:`\mathbf{a}_i` are training distributions in the columns of + matrix :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT + - reg_mis the marginal relaxation hyperparameter + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ Parameters ---------- - A : np.ndarray (d,n) - n training distributions a_i of size d - M : np.ndarray (d,d) - loss matrix for OT + A : np.ndarray (dim, n_hists) + `n_hists` training distributions a_i of dimension dim + M : np.ndarray (dim, dim) + ground metric matrix for OT. reg : float Entropy regularization term > 0 - alpha : float + reg_m : float Marginal relaxation term > 0 - weights : np.ndarray (n,) - Weights of each histogram a_i on the simplex (barycentric coodinates) + tau : float + Stabilization threshold for log domain absorption. + weights : np.ndarray (n_hists,) optional + Weight of each distribution (barycentric coodinates) + If None, uniform weights are used. numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshol on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + a : (dim,) ndarray + Unbalanced Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, + G. (2015). Iterative Bregman projections for regularized transportation + problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. + + + """ + dim, n_hists = A.shape + if weights is None: + weights = np.ones(n_hists) / n_hists + else: + assert(len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + fi = reg_m / (reg_m + reg) + + u = np.ones((dim, n_hists)) / dim + v = np.ones((dim, n_hists)) / dim + + # print(reg) + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + fi = reg_m / (reg_m + reg) + + cpt = 0 + err = 1. + alpha = np.zeros(dim) + beta = np.zeros(dim) + q = np.ones(dim) / dim + while (err > stopThr and cpt < numItermax): + qprev = q + Kv = K.dot(v) + f_alpha = np.exp(- alpha / (reg + reg_m)) + f_beta = np.exp(- beta / (reg + reg_m)) + f_alpha = f_alpha[:, None] + f_beta = f_beta[:, None] + u = ((A / (Kv + 1e-16)) ** fi) * f_alpha + Ktu = K.T.dot(u) + q = (Ktu ** (1 - fi)) * f_beta + q = q.dot(weights) ** (1 / (1 - fi)) + Q = q[:, None] + v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta + absorbing = False + if (u > tau).any() or (v > tau).any(): + absorbing = True + alpha = alpha + reg * np.log(np.max(u, 1)) + beta = beta + reg * np.log(np.max(v, 1)) + K = np.exp((alpha[:, None] + beta[None, :] - + M) / reg) + v = np.ones_like(v) + Kv = K.dot(v) + if (np.any(Ktu == 0.) + or np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn('Numerical errors at iteration %s' % cpt) + q = qprev + break + if (cpt % 10 == 0 and not absorbing) or cpt == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = abs(q - qprev).max() / max(abs(q).max(), + abs(qprev).max(), 1.) + if log: + log['err'].append(err) + if verbose: + if cpt % 50 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + if err > stopThr: + warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + + "Try a larger entropy `reg` or a lower mass `reg_m`." + + "Or a larger absorption threshold `tau`.") + if log: + log['niter'] = cpt + log['logu'] = np.log(u + 1e-16) + log['logv'] = np.log(v + 1e-16) + return q, log + else: + return q + + +def _barycenter_unbalanced(A, M, reg, reg_m, weights=None, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False): + r"""Compute the entropic unbalanced wasserstein barycenter of A. + + The function solves the following optimization problem with a + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized + Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT + - reg_mis the marginal relaxation hyperparameter + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + + Parameters + ---------- + A : np.ndarray (dim, n_hists) + `n_hists` training distributions a_i of dimension dim + M : np.ndarray (dim, dim) + ground metric matrix for OT. + reg : float + Entropy regularization term > 0 + reg_m: float + Marginal relaxation term > 0 + weights : np.ndarray (n_hists,) optional + Weight of each distribution (barycentric coodinates) + If None, uniform weights are used. + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional @@ -442,7 +851,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, Returns ------- - a : (d,) ndarray + a : (dim,) ndarray Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -451,12 +860,16 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. + (2015). Iterative Bregman projections for regularized transportation + problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprin + arXiv:1607.05816. """ - p, n_hists = A.shape + dim, n_hists = A.shape if weights is None: weights = np.ones(n_hists) / n_hists else: @@ -467,10 +880,10 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, K = np.exp(- M / reg) - fi = alpha / (alpha + reg) + fi = reg_m / (reg_m + reg) - v = np.ones((p, n_hists)) / p - u = np.ones((p, 1)) / p + v = np.ones((dim, n_hists)) / dim + u = np.ones((dim, 1)) / dim cpt = 0 err = 1. @@ -499,8 +912,11 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \ - np.sum((v - vprev) ** 2) / np.sum((v) ** 2) + err_u = abs(u - uprev).max() + err_u /= max(abs(u).max(), abs(uprev).max(), 1.) + err_v = abs(v - vprev).max() + err_v /= max(abs(v).max(), abs(vprev).max(), 1.) + err = 0.5 * (err_u + err_v) if log: log['err'].append(err) if verbose: @@ -512,8 +928,95 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, cpt += 1 if log: log['niter'] = cpt - log['u'] = u - log['v'] = v + log['logu'] = np.log(u + 1e-16) + log['logv'] = np.log(v + 1e-16) return q, log else: return q + + +def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False, **kwargs): + r"""Compute the entropic unbalanced wasserstein barycenter of A. + + The function solves the following optimization problem with a + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized + Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT + - reg_mis the marginal relaxation hyperparameter + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + + Parameters + ---------- + A : np.ndarray (dim, n_hists) + `n_hists` training distributions a_i of dimension dim + M : np.ndarray (dim, dim) + ground metric matrix for OT. + reg : float + Entropy regularization term > 0 + reg_m: float + Marginal relaxation term > 0 + weights : np.ndarray (n_hists,) optional + Weight of each distribution (barycentric coodinates) + If None, uniform weights are used. + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + a : (dim,) ndarray + Unbalanced Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. + (2015). Iterative Bregman projections for regularized transportation + problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprin + arXiv:1607.05816. + + """ + + if method.lower() == 'sinkhorn': + return _barycenter_unbalanced(A, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() == 'sinkhorn_stabilized': + return _barycenter_unbalanced_stabilized(A, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: + warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') + return _barycenter_unbalanced(A, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) -- cgit v1.2.3 From 7efea812ad0b1c7e3783397dbd8f3ad802fb7ac2 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 3 Sep 2019 17:26:30 +0200 Subject: same for unbalanced --- ot/unbalanced.py | 102 +++++++++++++++++++++++++++---------------------------- 1 file changed, 51 insertions(+), 51 deletions(-) (limited to 'ot/unbalanced.py') diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 3f71d28..25e4cf5 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -120,23 +120,23 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, """ if method.lower() == 'sinkhorn': - return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, **kwargs) + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) else: raise ValueError("Unknown method '%s'." % method) @@ -241,29 +241,29 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', if len(b.shape) < 2: b = b[:, None] if method.lower() == 'sinkhorn': - return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, **kwargs) + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) else: raise ValueError('Unknown method %s.' % method) -def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, - stopThr=1e-6, verbose=False, log=False, **kwargs): +def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -300,7 +300,7 @@ def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshol on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional @@ -439,9 +439,9 @@ def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, return u[:, None] * K * v[None, :] -def _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000, - stopThr=1e-6, verbose=False, log=False, - **kwargs): +def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, + **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -653,9 +653,9 @@ def _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=100 return ot_matrix -def _barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, - numItermax=1000, stopThr=1e-6, - verbose=False, log=False): +def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False): r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization. The function solves the following optimization problem: @@ -804,9 +804,9 @@ def _barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, return q -def _barycenter_unbalanced(A, M, reg, reg_m, weights=None, - numItermax=1000, stopThr=1e-6, - verbose=False, log=False): +def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False): r"""Compute the entropic unbalanced wasserstein barycenter of A. The function solves the following optimization problem with a @@ -1001,22 +1001,22 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, """ if method.lower() == 'sinkhorn': - return _barycenter_unbalanced(A, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return _barycenter_unbalanced_stabilized(A, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, **kwargs) + return barycenter_unbalanced_stabilized(A, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return _barycenter_unbalanced(A, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return barycenter_unbalanced(A, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) else: raise ValueError("Unknown method '%s'." % method) -- cgit v1.2.3 From 49d9b5cf4eecefdc0fff4db6c43e85d16e478efb Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 3 Sep 2019 17:35:23 +0200 Subject: fix doctest examples --- ot/unbalanced.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'ot/unbalanced.py') diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 25e4cf5..d516dfc 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -326,7 +326,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] - >>> ot.unbalanced._sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) + >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) array([[0.51122823, 0.18807035], [0.18807035, 0.51122823]]) @@ -510,7 +510,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] - >>> ot.unbalanced._sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.) + >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.) array([[0.51122823, 0.18807035], [0.18807035, 0.51122823]]) -- cgit v1.2.3