From e1b67c641da3b3e497db6811af2c200022b10302 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 3 Nov 2021 08:41:35 +0100 Subject: [WIP] Add debiased barycenter (Sinkhorn + convolutional sinkhorn) (#291) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add debiased sinkhorn barycenter + make loops pythonic * add debiased arg in tests * add 1d and 2d examples of debiased barycenters * fix doctest * fix flake8 * pep8 + make func private + add convergence warnings * remove rel paths + add rng + pylab to pyplot * fix stopping criterion debiased * pass alex * change params with new API * add logdomain barycenters + separate debiased API * test new API * fix jax read-only ? * raise error for jax * test catch jax error * fix pytest catch error * fix relative path * fix flake8 * add warn arg everywhere * fix ref number * catch warnings in tests * add contrib to readme + change ref number * fix convolution example + gallery thumbnails * increase coverage * fix flake Co-authored-by: Hicham Janati Co-authored-by: Rémi Flamary Co-authored-by: Alexandre Gramfort --- ot/bregman.py | 1491 ++++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 1151 insertions(+), 340 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index 0499b8e..786f151 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -7,7 +7,7 @@ Bregman projections solvers for entropic regularized OT # Nicolas Courty # Kilian Fatras # Titouan Vayer -# Hicham Janati +# Hicham Janati # Mokhtar Z. Alaya # Alexander Tong # Ievgen Redko @@ -25,7 +25,8 @@ from .backend import get_backend def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): + stopThr=1e-9, verbose=False, log=False, warn=True, + **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -43,8 +44,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) .. note:: This function is backend-compatible and will work on arrays from all compatible backends. @@ -77,7 +80,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, samples weights in the source domain b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -94,6 +98,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. Returns ------- @@ -117,13 +123,21 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, References ---------- - .. [2] M.
Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. - .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, + A., & Peyré, G. (2019, April). Interpolating between optimal transport + and MMD using Sinkhorn divergences. In The 22nd International Conference + on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. See Also -------- ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` :ref:`[10] ` - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling :ref:`[9] ` :ref:`[10] ` + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn + :ref:`[9] ` :ref:`[10] ` + ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling + :ref:`[9] ` :ref:`[10] ` """ if method.lower() == 'sinkhorn': return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': return sinkhorn_log(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) elif method.lower() == 'greenkhorn': return greenkhorn(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log) + stopThr=stopThr, verbose=verbose, log=log, + warn=warn) elif method.lower() == 'sinkhorn_stabilized': return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, + **kwargs) elif method.lower() == 'sinkhorn_epsilon_scaling': return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, + **kwargs) else: raise ValueError("Unknown method '%s'." % method)
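
For quick orientation, a minimal usage sketch of the dispatch above, including the `warn` flag this patch threads through every solver (toy data, numpy backend):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    a, b = ot.unif(50), ot.unif(50)                    # uniform histograms
    M = ot.dist(rng.randn(50, 2), rng.randn(50, 2))    # squared Euclidean cost

    G = ot.sinkhorn(a, b, M, reg=1e-1, method='sinkhorn_log')     # robust for small reg
    G = ot.sinkhorn(a, b, M, reg=1e-1, numItermax=5, warn=False)  # silence the convergence warning

Any method name outside the five handled above raises the ValueError shown.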
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): + stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the loss @@ -179,13 +200,16 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) .. note:: This function is backend-compatible and will work on arrays from all compatible backends. - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[2] ` **Choosing a Sinkhorn solver** @@ -212,7 +236,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, samples weights in the source domain b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -228,6 +253,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. Returns ------- @@ -252,19 +279,27 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. + arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation - algorithms for optimal transport via Sinkhorn iteration, Advances in Neural - Information Processing Systems (NIPS) 31, 2017 - - .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences.
In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. - + algorithms for optimal transport via Sinkhorn iteration, + Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., + Trouvé, A., & Peyré, G. (2019, April). + Interpolating between optimal transport and MMD using Sinkhorn + divergences. In The 22nd International Conference on Artificial + Intelligence and Statistics (pp. 2681-2690). PMLR. See Also -------- ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` ot.bregman.greenkhorn : Greenkhorn :ref:`[21] ` - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` :ref:`[10] ` + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` + :ref:`[10] ` """ @@ -317,8 +353,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, raise ValueError("Unknown method '%s'." % method) -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, + verbose=False, log=False, warn=True, + **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -335,10 +372,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp + matrix scaling algorithm as proposed in :ref:`[2] ` Parameters ---------- @@ -347,7 +387,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, samples weights in the source domain b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -360,6 +401,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. Returns ------- @@ -384,7 +427,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, References ---------- - .. [2] M.
Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 See Also @@ -427,9 +472,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, K = nx.exp(M / (-reg)) Kp = (1 / a).reshape(-1, 1) * K - cpt = 0 + err = 1 - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): uprev = u vprev = v KtransposeU = nx.dot(K.T, u) @@ -441,11 +486,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) + warnings.warn('Warning: numerical errors at iteration %d' % ii) u = uprev v = vprev break - if cpt % 10 == 0: + if ii % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: @@ -457,13 +502,20 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt = cpt + 1 + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: + log['niter'] = ii log['u'] = u log['v'] = v @@ -482,8 +534,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, return u.reshape((-1, 1)) * K * v.reshape((1, -1))
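
Background for the log-domain solver defined next: with a small `reg`, the kernel `exp(-M / reg)` used by `sinkhorn_knopp` underflows to exact zeros, while a `logsumexp`-based update stays finite. A minimal numpy/scipy illustration (not part of the patch):

    import numpy as np
    from scipy.special import logsumexp

    M = np.array([[0., 1.], [1., 0.]])
    reg = 1e-3
    print(np.exp(-M / reg))             # [[1. 0.] [0. 1.]] -- off-diagonal entries underflow
    print(logsumexp(-M / reg, axis=1))  # [0. 0.] -- finite in the log domain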
-def sinkhorn_log(a, b, M, reg, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, + log=False, warn=True, **kwargs): r""" Solve the entropic regularization optimal transport problem in log space and return the OT matrix @@ -528,6 +580,8 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. Returns ------- @@ -552,9 +606,15 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 - .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., + Trouvé, A., & Peyré, G. (2019, April). Interpolating between + optimal transport and MMD using Sinkhorn divergences. In The + 22nd International Conference on Artificial Intelligence and + Statistics (pp. 2681-2690). PMLR. See Also @@ -613,7 +673,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, if log: log = {'err': []} - Mr = M / (-reg) + Mr = - M / reg # we assume that no distances are null except those of the diagonal of # distances @@ -630,14 +690,13 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, loga = nx.log(a) logb = nx.log(b) - cpt = 0 err = 1 - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): v = logb - nx.logsumexp(Mr + u[:, None], 0) u = loga - nx.logsumexp(Mr + v[None, :], 1) - if cpt % 10 == 0: + if ii % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations @@ -648,13 +707,20 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, log['err'].append(err) if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt = cpt + 1 + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: + log['niter'] = ii log['log_u'] = u log['log_v'] = v log['u'] = nx.exp(u) log['v'] = nx.exp(v) @@ -667,11 +733,13 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, - log=False): + log=False, warn=True): r""" Solve the entropic regularization optimal transport problem and return the OT matrix - The algorithm used is based on the paper :ref:`[22] ` which is a stochastic version of the Sinkhorn-Knopp algorithm :ref:`[2] ` + The algorithm used is based on the paper :ref:`[22] ` + which is a stochastic version of the Sinkhorn-Knopp + algorithm :ref:`[2] ` The function solves the following optimization problem: @@ -686,8 +754,10 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) Parameters ---------- @@ -696,7 +766,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, samples weights in the source domain b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -707,6 +778,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, Stop threshold on error (>0) log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. Returns ------- @@ -731,9 +804,14 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - ..
[2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 - .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time + approximation algorithms for optimal transport via Sinkhorn + iteration, Advances in Neural Information Processing + Systems (NIPS) 31, 2017 See Also @@ -747,7 +825,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, nx = get_backend(M, a, b) if nx.__name__ == "jax": - raise TypeError("JAX arrays have been received. Greenkhorn is not compatible with JAX") + raise TypeError("JAX arrays have been received. Greenkhorn is not " + "compatible with JAX") if len(a) == 0: a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] @@ -771,7 +850,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log['u'] = u log['v'] = v - for i in range(numItermax): + for ii in range(numItermax): i_1 = nx.argmax(nx.abs(viol)) i_2 = nx.argmax(nx.abs(viol_2)) m_viol_1 = nx.abs(viol[i_1]) @@ -795,14 +874,17 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, viol += (-old_v + new_v) * K[:, i_2] * u viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2] v[i_2] = new_v - # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) if stopThr_val <= stopThr: break else: - print('Warning: Algorithm did not converge') + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: + log["n_iter"] = ii log['u'] = u log['v'] = v @@ -814,7 +896,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, - log=False, **kwargs): + log=False, warn=True, **kwargs): r""" Solve the entropic regularization OT problem with log stabilization @@ -831,13 +913,17 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in :ref:`[2] ` but with the log stabilization - proposed in :ref:`[10] ` an defined in :ref:`[9] ` (Algo 3.1) . + scaling algorithm as proposed in :ref:`[2] ` + but with the log stabilization + proposed in :ref:`[10] ` and defined in + :ref:`[9] ` (Algo 3.1).
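
Schematically, the absorption step of Algorithm 3.1 mentioned above folds oversized scalings into the dual potentials before they overflow; a hypothetical numpy sketch of the idea (names invented, not the patch code):

    import numpy as np

    def absorb(alpha, beta, u, v, reg):
        # move the current scalings into the log-domain potentials ...
        alpha = alpha + reg * np.log(u)
        beta = beta + reg * np.log(v)
        # ... and restart u, v from a well-scaled point
        return alpha, beta, np.ones_like(u), np.ones_like(v)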
Parameters @@ -851,7 +937,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, reg : float Regularization term >0 tau : float - threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` + for log scaling warmstart : table of vectors if given then starting values for alpha and beta log scalings numItermax : int, optional @@ -862,6 +949,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. Returns ------- @@ -886,11 +975,17 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. + arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. See Also @@ -920,7 +1015,6 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, dim_a = len(a) dim_b = len(b) - cpt = 0 if log: log = {'err': []} @@ -935,7 +1029,9 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, u = nx.ones((dim_a, n_hists), type_as=M) / dim_a v = nx.ones((dim_b, n_hists), type_as=M) / dim_b else: - u, v = nx.ones(dim_a, type_as=M) / dim_a, nx.ones(dim_b, type_as=M) / dim_b + u, v = nx.ones(dim_a, type_as=M), nx.ones(dim_b, type_as=M) + u /= dim_a + v /= dim_b def get_K(alpha, beta): """log space computation""" @@ -947,21 +1043,17 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) / reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b)))) - # print(np.min(K)) - K = get_K(alpha, beta) transp = K - loop = 1 - cpt = 0 err = 1 - while loop: + for ii in range(numItermax): uprev = u vprev = v # sinkhorn update - v = b / (nx.dot(K.T, u) + 1e-16) - u = a / (nx.dot(K, v) + 1e-16) + v = b / (nx.dot(K.T, u)) + u = a / (nx.dot(K, v)) # remove numerical problems and store them in K if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau: @@ -977,7 +1069,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, v = nx.ones(dim_b, type_as=M) / dim_b K = get_K(alpha, beta) - if cpt % print_period == 0: + if ii % print_period == 0: # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: @@ -993,33 +1085,33 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, log['err'].append(err) if verbose: - if cpt % (print_period * 20) == 0: + if ii % (print_period * 20) == 0: print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) if err <= stopThr: - loop = False - - if cpt >= numItermax: - loop = False + break if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)): # we have reached the machine precision # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) + warnings.warn('Numerical errors at iteration %d' % ii) u = uprev v = vprev break - - cpt = cpt + 1 - + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: if n_hists: alpha = alpha[:, None] beta = beta[:, None] logu = alpha / reg + nx.log(u) logv = beta / reg + nx.log(v) + log["n_iter"] = ii log['logu'] = logu log['logv'] = logv log['alpha'] = alpha + reg * nx.log(u) @@ -1048,13 +1140,11 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, - log=False, **kwargs): + log=False, warn=True, **kwargs): r""" Solve the entropic regularization optimal transport problem with log stabilization and epsilon scaling. - The function solves the following optimization problem: - .. math:: \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) @@ -1064,16 +1154,16 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, \gamma &\geq 0 where : - - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) - - + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights + (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in :ref:`[2] ` but with the log stabilization - proposed in :ref:`[10] ` and the log scaling proposed in :ref:`[9] ` algorithm 3.2 - + scaling algorithm as proposed in :ref:`[2] ` + but with the log stabilization + proposed in :ref:`[10] ` and the log scaling + proposed in :ref:`[9] ` algorithm 3.2 Parameters ---------- @@ -1086,7 +1176,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, reg : float Regularization term >0 tau : float - threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{b}` for log scaling + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` + for log scaling warmstart : tuple of vectors if given then starting values for alpha and beta log scalings numItermax : int, optional @@ -1101,6 +1192,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge.
Returns ------- @@ -1108,10 +1201,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters - Examples -------- - >>> import ot >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] @@ -1123,19 +1214,19 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, .. _references-sinkhorn-epsilon-scaling: References ---------- + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. See Also -------- ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT - """ a, b, M = list_to_array(a, b, M) @@ -1155,7 +1246,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numItermin = 35 numItermax = max(numItermin, numItermax) # ensure that last value is exact - cpt = 0 + ii = 0 if log: log = {'err': []} @@ -1170,12 +1261,10 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, def get_reg(n): # exponential decreasing return (epsilon0 - reg) * np.exp(-n) + reg - loop = 1 - cpt = 0 err = 1 - while loop: + for ii in range(numItermax): - regi = get_reg(cpt) + regi = get_reg(ii) G, logi = sinkhorn_stabilized(a, b, M, regi, numItermax=numInnerItermax, stopThr=1e-9, @@ -1185,10 +1274,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, alpha = logi['alpha'] beta = logi['beta'] - if cpt >= numItermax: - loop = False - - if cpt % (print_period) == 0: # spsion nearly converged + if ii % (print_period) == 0: # epsilon nearly converged # we can speed up the process by checking for the error only all # the 10th iterations transp = G @@ -1197,19 +1283,22 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, log['err'].append(err) if verbose: - if cpt % (print_period * 10) == 0: + if ii % (print_period * 10) == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - - if err <= stopThr and cpt > numItermin: - loop = False + print('{:5d}|{:8e}|'.format(ii, err)) + if err <= stopThr and ii > numItermin: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['alpha'] = alpha log['beta'] = beta log['warmstart'] = (log['alpha'], log['beta']) + log['niter'] = ii return G, log else: return G
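
The `get_reg` schedule above decays the working regularization from `epsilon0` toward the target `reg`, warm-starting each stabilized solve; a quick check of the decay (plain numpy, illustrative values):

    import numpy as np

    reg, epsilon0 = 1e-2, 1e4
    for n in range(5):
        print(n, (epsilon0 - reg) * np.exp(-n) + reg)   # 1e4, ~3.7e3, ~1.4e3, ~5.0e2, ~1.8e2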
@@ -1245,7 +1334,7 @@ def projC(gamma, q): def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-4, verbose=False, log=False, **kwargs): + stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: .. math:: \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - :math:`OT_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + if `method` is `sinkhorn` or `sinkhorn_stabilized` or `sinkhorn_log`. + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling + algorithm as proposed in :ref:`[3] ` Parameters ---------- A : array-like, shape (dim, n_hists) `n_hists` training distributions :math:`a_i` of size `dim` M : array-like, shape (dim, dim) loss matrix for OT reg : float Regularization term > 0 method : str (optional) - method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' + method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' or 'sinkhorn_log' weights : array-like, shape (n_hists,) Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates) numItermax : int, optional Max number of iterations stopThr : float, optional Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. Returns ------- a : (dim,) array-like Wasserstein barycenter log : dict log dictionary return only if log==True in parameters References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. """ @@ -1303,18 +1401,24 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, return barycenter_sinkhorn(A, M, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) elif method.lower() == 'sinkhorn_stabilized': return barycenter_stabilized(A, M, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, **kwargs) + elif method.lower() == 'sinkhorn_log': + return _barycenter_sinkhorn_log(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method)
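
A toy usage sketch of the extended dispatch above, including the new 'sinkhorn_log' method (histograms and grid are made up for illustration):

    import numpy as np
    import ot

    n = 100
    x = np.arange(n, dtype=np.float64)
    a1 = np.exp(-(x - 30) ** 2 / 50)
    a2 = np.exp(-(x - 70) ** 2 / 50)
    A = np.vstack((a1 / a1.sum(), a2 / a2.sum())).T   # shape (dim, n_hists)

    M = ot.utils.dist0(n)                             # squared distances on the grid
    M /= M.max()
    bar = ot.bregman.barycenter(A, M, reg=1e-2, method='sinkhorn_log')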
""" @@ -1303,18 +1401,24 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, return barycenter_sinkhorn(A, M, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) elif method.lower() == 'sinkhorn_stabilized': return barycenter_stabilized(A, M, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, **kwargs) + elif method.lower() == 'sinkhorn_log': + return _barycenter_sinkhorn_log(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): + stopThr=1e-4, verbose=False, log=False, warn=True): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: @@ -1324,11 +1428,15 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[3]`. Parameters ---------- @@ -1348,6 +1456,8 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns @@ -1362,7 +1472,9 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. """ @@ -1378,43 +1490,109 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, if log: log = {'err': []} - # M = M/np.median(M) # suggested by G. 
Peyre K = nx.exp(-M / reg) - cpt = 0 err = 1 UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) u = (geometricMean(UKv) / UKv.T).T - while (err > stopThr and cpt < numItermax): - cpt = cpt + 1 + for ii in range(numItermax): + UKv = u * nx.dot(K, A / nx.dot(K, u)) u = (u.T * geometricBar(weights, UKv)).T / UKv - if cpt % 10 == 1: + if ii % 10 == 1: err = nx.sum(nx.std(UKv, axis=1)) # log and verbose print if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: - log['niter'] = cpt + log['niter'] = ii return geometricBar(weights, UKv), log else: return geometricBar(weights, UKv) +def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, warn=True): + r"""Compute the entropic wasserstein barycenter in log-domain + """ + + A, M = list_to_array(A, M) + dim, n_hists = A.shape + + nx = get_backend(A, M) + + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. Use numpy or torch arrays instead.") + + if weights is None: + weights = nx.ones(n_hists, type_as=A) / n_hists + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + M = - M / reg + logA = nx.log(A + 1e-15) + log_KU, G = nx.zeros((2, *logA.shape), type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros(dim, type_as=A) + for k in range(n_hists): + f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1) + log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0) + log_bar = log_bar + weights[k] * log_KU[:, k] + + if ii % 10 == 1: + err = nx.exp(G + log_KU).std(axis=1).sum() + + # log and verbose print + if log: + log['err'].append(err) + + if err < stopThr: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + G = log_bar[:, None] - log_KU + + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): + stopThr=1e-4, verbose=False, log=False, warn=True): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization. 
The function solves the following optimization problem: @@ -1424,11 +1602,15 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling + algorithm as proposed in :ref:`[3] ` Parameters ---------- A : array-like, shape (dim, n_hists) `n_hists` training distributions :math:`a_i` of size `dim` M : array-like, shape (dim, dim) loss matrix for OT reg : float Regularization term > 0 tau : float - threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` + for log scaling weights : array-like, shape (n_hists,) Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates) numItermax : int, optional Max number of iterations stopThr : float, optional Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. Returns ------- a : (dim,) array-like Wasserstein barycenter log : dict log dictionary return only if log==True in parameters References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. """ @@ -1486,19 +1673,18 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, K = nx.exp(-M / reg) - cpt = 0 err = 1.
alpha = nx.zeros((dim,), type_as=M) beta = nx.zeros((dim,), type_as=M) q = nx.ones((dim,), type_as=M) / dim - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): qprev = q Kv = nx.dot(K, v) - u = A / (Kv + 1e-16) + u = A / Kv Ktu = nx.dot(K.T, u) q = geometricBar(weights, Ktu) Q = q[:, None] - v = Q / (Ktu + 1e-16) + v = Q / Ktu absorbing = False if nx.any(u > tau) or nx.any(v > tau): absorbing = True @@ -1512,40 +1698,244 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % cpt) + warnings.warn('Numerical errors at iteration %s' % ii) q = qprev break - if (cpt % 10 == 0 and not absorbing) or cpt == 0: + if (ii % 10 == 0 and not absorbing) or ii == 0: # we can speed up the process by checking for the error only all # the 10th iterations err = nx.max(nx.abs(u * Kv - A)) if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 50 == 0: + if ii % 50 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) - cpt += 1 - if err > stopThr: - warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + - "Try a larger entropy `reg`" + - "Or a larger absorption threshold `tau`.") + else: + if warn: + warnings.warn("Stabilized Sinkhorn did not converge. " + + "Try a larger entropy `reg` " + + "or a larger absorption threshold `tau`.") if log: - log['niter'] = cpt - log['logu'] = nx.log(u + 1e-16) - log['logv'] = nx.log(v + 1e-16) + log['niter'] = ii + log['logu'] = nx.log(u + 1e-16) + log['logv'] = nx.log(v + 1e-16) return q, log else: return q -def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, - stopThr=1e-9, stabThr=1e-30, verbose=False, - log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` - where :math:`\mathbf{A}` is a collection of 2D images. +def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, + stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): + r"""Compute the debiased Sinkhorn barycenter of distributions A + + The function solves the following optimization problem: + + ..
math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) + where : + + - :math:`S_{reg}(\cdot,\cdot)` is the debiased Sinkhorn divergence + (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT + + The algorithm used for solving the problem is the debiased Sinkhorn + algorithm as proposed in :ref:`[37] ` + + Parameters + ---------- + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`a_i` of size `dim` + M : array-like, shape (dim, dim) + loss matrix for OT + reg : float + Regularization term > 0 + method : str (optional) + method used for the solver either 'sinkhorn' or 'sinkhorn_log' + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`a_i` on the simplex (barycentric coordinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. + + + + Returns + ------- + a : (dim,) array-like + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + .. _references-sinkhorn-debiased: + References + ---------- + + .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International + Conference on Machine Learning, PMLR 119:4692-4701, 2020 + """ + + if method.lower() == 'sinkhorn': + return _barycenter_debiased(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) + elif method.lower() == 'sinkhorn_log': + return _barycenter_debiased_log(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, warn=True): + r"""Compute the debiased sinkhorn barycenter of distributions A. + """ + + A, M = list_to_array(A, M) + + nx = get_backend(A, M) + + if weights is None: + weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1] + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + K = nx.exp(-M / reg) + + err = 1 + + UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) + + u = (geometricMean(UKv) / UKv.T).T + c = nx.ones(A.shape[0], type_as=A) + bar = nx.ones(A.shape[0], type_as=A) + + for ii in range(numItermax): + bold = bar + UKv = nx.dot(K, A / nx.dot(K, u)) + bar = c * geometricBar(weights, UKv) + u = bar[:, None] / UKv + c = (c * bar / nx.dot(K, c)) ** 0.5 + + if ii % 10 == 9: + err = abs(bar - bold).max() / max(bar.max(), 1.) + + # log and verbose print + if log: + log['err'].append(err) + + # debiased Sinkhorn does not converge monotonically + # guarantee a few iterations are done before stopping + if err < stopThr and ii > 20: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return bar, log + else: + return bar
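
The debiasing in `_barycenter_debiased` is carried by the extra vector `c`: at its fixed point `c * (Kc) = bar`, which is what removes the entropic blur from the barycenter. A one-line schematic of that update (toy numpy, not the patch code):

    import numpy as np

    K = np.exp(-np.random.rand(4, 4))   # any positive kernel
    bar = np.full(4, 0.25)              # current barycenter estimate
    c = np.ones(4)
    c = (c * bar / K.dot(c)) ** 0.5     # same fixed-point update as in the loop above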
def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, + warn=True): + r"""Compute the debiased sinkhorn barycenter in log domain. + """ + + A, M = list_to_array(A, M) + dim, n_hists = A.shape + + nx = get_backend(A, M) + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. Use numpy or torch arrays instead.") + + if weights is None: + weights = nx.ones(n_hists, type_as=A) / n_hists + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + M = - M / reg + logA = nx.log(A + 1e-15) + log_KU, G = nx.zeros((2, *logA.shape), type_as=A) + c = nx.zeros(dim, type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros(dim, type_as=A) + for k in range(n_hists): + f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1) + log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0) + log_bar += weights[k] * log_KU[:, k] + log_bar += c + if ii % 10 == 1: + err = nx.exp(G + log_KU).std(axis=1).sum() + + # log and verbose print + if log: + log['err'].append(err) + + if err < stopThr and ii > 20: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + G = log_bar[:, None] - log_KU + for _ in range(10): + c = 0.5 * (c + log_bar - nx.logsumexp(M + c[:, None], axis=0)) + + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + +def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000, + stopThr=1e-4, verbose=False, log=False, + warn=True, **kwargs): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images.
The function solves the following optimization problem: @@ -1554,11 +1944,14 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}` + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two dimensions + of matrix :math:`\mathbf{A}` - `reg` is the regularization strength scalar value - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[21] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm + as proposed in :ref:`[21] ` Parameters ---------- A : array-like, shape (n_hists, width, height) `n` distributions (2D images) of size `width` x `height` reg : float Regularization term >0 weights : array-like, shape (n_hists,) Weights of each image on the simplex (barycentric coordinates) + method : string, optional + method used for the solver either 'sinkhorn' or 'sinkhorn_log' numItermax : int, optional Max number of iterations stopThr : float, optional Stop threshold on error (> 0) stabThr : float, optional Stabilization threshold to avoid numerical precision issue verbose : bool, optional Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. Returns ------- a : array-like, shape (width, height) 2D Wasserstein barycenter log : dict log dictionary return only if log==True in parameters References ---------- - .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: Efficient optimal transportation on geometric domains. ACM Transactions on Graphics (TOG), 34(4), 66 + .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, + A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: + Efficient optimal transportation on geometric domains. ACM Transactions + on Graphics (TOG), 34(4), 66 + .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th + International Conference on Machine Learning, PMLR 119:4692-4701, 2020 + """ + if method.lower() == 'sinkhorn': + return _convolutional_barycenter2d(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return _convolutional_barycenter2d_log(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method)
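
A toy usage sketch of the new dispatch above (made-up images; the reg value is only indicative):

    import numpy as np
    import ot

    grid = np.linspace(0, 1, 40)
    X, Y = np.meshgrid(grid, grid)
    img1 = np.exp(-((X - 0.3) ** 2 + (Y - 0.3) ** 2) / 0.01)
    img2 = np.exp(-((X - 0.7) ** 2 + (Y - 0.7) ** 2) / 0.01)
    A = np.stack((img1 / img1.sum(), img2 / img2.sum()))   # shape (n_hists, width, height)

    bar = ot.bregman.convolutional_barycenter2d(A, reg=4e-3, method='sinkhorn_log')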
""" A = list_to_array(A) @@ -1608,65 +2032,373 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, if log: log = {'err': []} - b = nx.zeros(A.shape[1:], type_as=A) + bar = nx.ones(A.shape[1:], type_as=A) + bar /= bar.sum() U = nx.ones(A.shape, type_as=A) - KV = nx.ones(A.shape, type_as=A) - - cpt = 0 + V = nx.ones(A.shape, type_as=A) err = 1 # build the convolution operator # this is equivalent to blurring on horizontal then vertical directions t = nx.linspace(0, 1, A.shape[1]) [Y, X] = nx.meshgrid(t, t) - xi1 = nx.exp(-(X - Y) ** 2 / reg) + K1 = nx.exp(-(X - Y) ** 2 / reg) t = nx.linspace(0, 1, A.shape[2]) [Y, X] = nx.meshgrid(t, t) - xi2 = nx.exp(-(X - Y) ** 2 / reg) - - def K(x): - return nx.dot(nx.dot(xi1, x), xi2) - - while (err > stopThr and cpt < numItermax): - - bold = b - cpt = cpt + 1 - - b = nx.zeros(A.shape[1:], type_as=A) - KV_cols = [] - for r in range(A.shape[0]): - KV_col_r = K(A[r, :, :] / nx.maximum(stabThr, K(U[r, :, :]))) - b += weights[r] * nx.log(nx.maximum(stabThr, U[r, :, :] * KV_col_r)) - KV_cols.append(KV_col_r) - KV = nx.stack(KV_cols) - b = nx.exp(b) - - U = nx.stack([ - b / nx.maximum(stabThr, KV[r, :, :]) - for r in range(A.shape[0]) - ]) - if cpt % 10 == 1: - err = nx.sum(nx.abs(bold - b)) + K2 = nx.exp(-(X - Y) ** 2 / reg) + + def convol_imgs(imgs): + kx = nx.einsum("...ij,kjl->kil", K1, imgs) + kxy = nx.einsum("...ij,klj->kli", K2, kx) + return kxy + + KU = convol_imgs(U) + for ii in range(numItermax): + V = bar[None] / KU + KV = convol_imgs(V) + U = A / KV + KU = convol_imgs(U) + bar = nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0)) + if ii % 10 == 9: + err = (V * KU).std(axis=0).sum() + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + + else: + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") + if log: + log['niter'] = ii + log['U'] = U + return bar, log + else: + return bar + + +def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000, + stopThr=1e-4, stabThr=1e-30, verbose=False, + log=False, warn=True): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images in log-domain. + """ + + A = list_to_array(A) + + nx = get_backend(A) + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. 
Use numpy or torch arrays instead.")
+
+    n_hists, width, height = A.shape
+
+    if weights is None:
+        weights = nx.ones((n_hists,), type_as=A) / n_hists
+    else:
+        assert (len(weights) == n_hists)
+
+    if log:
+        log = {'err': []}
+
+    err = 1
+    # build the convolution operator
+    # this is equivalent to blurring on horizontal then vertical directions
+    t = nx.linspace(0, 1, width)
+    [Y, X] = nx.meshgrid(t, t)
+    M1 = - (X - Y) ** 2 / reg
+
+    t = nx.linspace(0, 1, height)
+    [Y, X] = nx.meshgrid(t, t)
+    M2 = - (X - Y) ** 2 / reg
+
+    def convol_img(log_img):
+        log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
+        log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
+        return log_img
+
+    logA = nx.log(A + stabThr)
+    log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A)
+    err = 1
+    for ii in range(numItermax):
+        log_bar = nx.zeros((width, height), type_as=A)
+        for k in range(n_hists):
+            f = logA[k] - convol_img(G[k])
+            log_KU[k] = convol_img(f)
+            log_bar = log_bar + weights[k] * log_KU[k]
+
+        if ii % 10 == 9:
+            err = nx.exp(G + log_KU).std(axis=0).sum()
+            # log and verbose print
+            if log:
+                log['err'].append(err)
+
+            if verbose:
+                if ii % 200 == 0:
+                    print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+                    print('{:5d}|{:8e}|'.format(ii, err))
+            if err < stopThr:
+                break
+        G = log_bar[None, :, :] - log_KU
+
+    else:
+        if warn:
+            warnings.warn("Convolutional Sinkhorn did not converge. "
+                          "Try a larger number of iterations `numItermax` "
+                          "or a larger entropy `reg`.")
+    if log:
+        log['niter'] = ii
+        return nx.exp(log_bar), log
+    else:
+        return nx.exp(log_bar)
+
+
+def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn",
+                                        numItermax=10000, stopThr=1e-3,
+                                        verbose=False, log=False, warn=True,
+                                        **kwargs):
+    r"""Compute the debiased Sinkhorn barycenter of distributions A
+    where A is a collection of 2D images.
+
+    The function solves the following optimization problem:
+
+    .. math::
+       \mathbf{a} = arg\min_\mathbf{a} \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i)
+
+    where :
+
+    - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein
+      distance (see :py:func:`ot.bregman.sinkhorn_debiased`)
+    - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two
+      dimensions of matrix :math:`\mathbf{A}`
+    - `reg` is the regularization strength scalar value
+
+    The algorithm used for solving the problem is the debiased Sinkhorn scaling
+    algorithm as proposed in :ref:`[37] <references-sinkhorn-debiased>`
+
+    Parameters
+    ----------
+    A : array-like, shape (n_hists, width, height)
+        `n` distributions (2D images) of size `width` x `height`
+    reg : float
+        Regularization term >0
+    weights : array-like, shape (n_hists,)
+        Weights of each image on the simplex (barycentric coordinates)
+    method : string, optional
+        method used for the solver, either 'sinkhorn' or 'sinkhorn_log'
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (> 0)
+    stabThr : float, optional
+        Stabilization threshold to avoid numerical precision issue
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+    warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
+
+
+    Returns
+    -------
+    a : array-like, shape (width, height)
+        2D Wasserstein barycenter
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+
+    .. _references-sinkhorn-debiased:
+    References
+    ----------
+
+    .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th
+        International Conference on Machine Learning, PMLR 119:4692-4701, 2020
+    """
+
+    if method.lower() == 'sinkhorn':
+        return _convolutional_barycenter2d_debiased(A, reg, weights=weights,
+                                                    numItermax=numItermax,
+                                                    stopThr=stopThr, verbose=verbose,
+                                                    log=log, warn=warn,
+                                                    **kwargs)
+    elif method.lower() == 'sinkhorn_log':
+        return _convolutional_barycenter2d_debiased_log(A, reg, weights=weights,
+                                                        numItermax=numItermax,
+                                                        stopThr=stopThr, verbose=verbose,
+                                                        log=log, warn=warn,
+                                                        **kwargs)
+    else:
+        raise ValueError("Unknown method '%s'." % method)
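
The debiased variant shares the same signature as `convolutional_barycenter2d`; a short sketch (again illustrative with synthetic data, assuming the patched module):

import numpy as np
import ot

rng = np.random.RandomState(0)
A = rng.uniform(size=(2, 20, 20))
A /= A.sum(axis=(1, 2), keepdims=True)

# debiasing removes most of the entropic blur of the regularized barycenter
# while reusing the same convolutional solver underneath
bar_db = ot.bregman.convolutional_barycenter2d_debiased(A, reg=0.01,
                                                        method='sinkhorn_log')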
+
+
+def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,
+                                         stopThr=1e-3, stabThr=1e-15, verbose=False,
+                                         log=False, warn=True):
+    r"""Compute the debiased barycenter of 2D images via Sinkhorn convolutions.
+    """
+
+    A = list_to_array(A)
+    n_hists, width, height = A.shape
+
+    nx = get_backend(A)
+
+    if weights is None:
+        weights = nx.ones((n_hists,), type_as=A) / n_hists
+    else:
+        assert (len(weights) == n_hists)
+
+    if log:
+        log = {'err': []}
+
+    bar = nx.ones((width, height), type_as=A)
+    bar /= width * height
+    U = nx.ones(A.shape, type_as=A)
+    V = nx.ones(A.shape, type_as=A)
+    c = nx.ones(A.shape[1:], type_as=A)
+    err = 1
+
+    # build the convolution operator
+    # this is equivalent to blurring on horizontal then vertical directions
+    t = nx.linspace(0, 1, width)
+    [Y, X] = nx.meshgrid(t, t)
+    K1 = nx.exp(-(X - Y) ** 2 / reg)
+
+    t = nx.linspace(0, 1, height)
+    [Y, X] = nx.meshgrid(t, t)
+    K2 = nx.exp(-(X - Y) ** 2 / reg)
+
+    def convol_imgs(imgs):
+        kx = nx.einsum("...ij,kjl->kil", K1, imgs)
+        kxy = nx.einsum("...ij,klj->kli", K2, kx)
+        return kxy
+
+    KU = convol_imgs(U)
+    for ii in range(numItermax):
+        V = bar[None] / KU
+        KV = convol_imgs(V)
+        U = A / KV
+        KU = convol_imgs(U)
+        bar = c * nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+
+        for _ in range(10):
+            c = (c * bar / convol_imgs(c[None]).squeeze()) ** 0.5
+
+        if ii % 10 == 9:
+            err = (V * KU).std(axis=0).sum()
             # log and verbose print
             if log:
                 log['err'].append(err)
 
             if verbose:
-                if cpt % 200 == 0:
+                if ii % 200 == 0:
                     print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
-                    print('{:5d}|{:8e}|'.format(cpt, err))
+                    print('{:5d}|{:8e}|'.format(ii, err))
+            # debiased Sinkhorn does not converge monotonically
+            # guarantee a few iterations are done before stopping
+            if err < stopThr and ii > 20:
+                break
+    else:
+        if warn:
+            warnings.warn("Sinkhorn did not converge. You might want to "
+                          "increase the number of iterations `numItermax` "
+                          "or the regularization parameter `reg`.")
     if log:
-        log['niter'] = cpt
+        log['niter'] = ii
         log['U'] = U
-        return b, log
+        return bar, log
+    else:
+        return bar
+
+
+def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10000,
+                                             stopThr=1e-3, stabThr=1e-30, verbose=False,
+                                             log=False, warn=True):
+    r"""Compute the debiased barycenter of 2D images in log-domain.
+    """
+
+    A = list_to_array(A)
+    n_hists, width, height = A.shape
+    nx = get_backend(A)
+    if nx.__name__ == "jax":
+        raise NotImplementedError("Log-domain functions are not yet implemented"
+                                  " for Jax.
Use numpy or torch arrays instead.") + if weights is None: + weights = nx.ones((n_hists,), type_as=A) / n_hists + else: + assert (len(weights) == A.shape[0]) + + if log: + log = {'err': []} + + err = 1 + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, width) + [Y, X] = nx.meshgrid(t, t) + M1 = - (X - Y) ** 2 / reg + + t = nx.linspace(0, 1, height) + [Y, X] = nx.meshgrid(t, t) + M2 = - (X - Y) ** 2 / reg + + def convol_img(log_img): + log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) + log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T + return log_img + + logA = nx.log(A + stabThr) + log_bar, c = nx.zeros((2, width, height), type_as=A) + log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros((width, height), type_as=A) + for k in range(n_hists): + f = logA[k] - convol_img(G[k]) + log_KU[k] = convol_img(f) + log_bar = log_bar + weights[k] * log_KU[k] + log_bar += c + for _ in range(10): + c = 0.5 * (c + log_bar - convol_img(c)) + + if ii % 10 == 9: + err = nx.exp(G + log_KU).std(axis=0).sum() + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr and ii > 20: + break + G = log_bar[None, :, :] - log_KU + else: - return b + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, - stopThr=1e-3, verbose=False, log=False): + stopThr=1e-3, verbose=False, log=False, warn=True): r""" Compute the unmixing of an observation with a given dictionary using Wasserstein distance @@ -1679,16 +2411,21 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, where : - - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)` + - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + with M loss matrix (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, + its expected shape is `(dim_a, n_atoms)` - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms` - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a` - :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (`dim_a`, `dim_a`) for OT data fitting - - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization term and the cost matrix (`dim_prior`, `n_atoms`) regularization + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the + cost matrix (`dim_a`, `dim_a`) for OT data fitting + - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization + term and the cost matrix (`dim_prior`, `n_atoms`) regularization - :math:`\\alpha` weight data fitting and regularization - The optimization problem is solved following the algorithm described in :ref:`[4] ` + The 
optimization problem is solved following the algorithm described
+    in :ref:`[4] <references-unmix>`
 
 
     Parameters
@@ -1717,7 +2454,8 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
         Print information along iterations
     log : bool, optional
         record log if True
-
+    warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
 
     Returns
     -------
@@ -1731,8 +2469,10 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
     References
     ----------
 
-    .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Whorkshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016.
-
+    .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti,
+        Supervised planetary unmixing with optimal transport, Workshop
+        on Hyperspectral Image and Signal Processing :
+        Evolution in Remote Sensing (WHISPERS), 2016.
     """
 
     a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0)
@@ -1747,12 +2487,11 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
 
     old = h0
     err = 1
-    cpt = 0
     # log = {'niter':0, 'all_err':[]}
     if log:
         log = {'err': []}
 
-    while (err > stopThr and cpt < numItermax):
+    for ii in range(numItermax):
         K = projC(K, a)
         K0 = projC(K0, h0)
         new = nx.sum(K0, axis=1)
@@ -1770,22 +2509,27 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
             log['err'].append(err)
 
         if verbose:
-            if cpt % 200 == 0:
+            if ii % 200 == 0:
                 print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
-            print('{:5d}|{:8e}|'.format(cpt, err))
-
-        cpt = cpt + 1
-
+            print('{:5d}|{:8e}|'.format(ii, err))
+        if err < stopThr:
+            break
+    else:
+        if warn:
+            warnings.warn("Unmixing algorithm did not converge. You might want to "
+                          "increase the number of iterations `numItermax` "
+                          "or the regularization parameter `reg`.")
     if log:
-        log['niter'] = cpt
+        log['niter'] = ii
         return nx.sum(K0, axis=1), log
     else:
         return nx.sum(K0, axis=1)
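
A small synthetic call of `unmix` (a sketch, not from the patch; shapes follow the docstring above and the toy dictionary is hypothetical):

import numpy as np
import ot

x = np.arange(4.0)
M = (x[:, None] - x[None, :]) ** 2
M /= M.max()                          # cost between observation bins
M0 = np.array([[0., 1.], [1., 0.]])   # cost between the two atoms
D = np.array([[.5, 0.], [.5, 0.], [0., .5], [0., .5]])  # atoms as columns
a = np.array([.4, .1, .1, .4])        # observed distribution
h0 = np.array([.5, .5])               # prior on the mixture weights
h = ot.bregman.unmix(a, D, M, M0, h0, reg=0.1, reg0=0.1, alpha=0.5)
# h lives on the simplex and estimates the weight of each atom in a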
 
 
 def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
-                     stopThr=1e-6, verbose=False, log=False, **kwargs):
-    r'''Joint OT and proportion estimation for multi-source target shift as proposed in :ref:`[27] <references-jcpot-barycenter>`
+                     stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs):
+    r'''Joint OT and proportion estimation for multi-source target shift as
+    proposed in :ref:`[27] <references-jcpot-barycenter>`
 
     The function solves the following optimization problem:
 
@@ -1799,16 +2543,23 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
 
     where :
 
     - :math:`\lambda_k` is the weight of `k`-th source domain
-    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`)
-    - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain defined as in [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source domain and `C` is the number of classes
+    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
+      (see :py:func:`ot.bregman.sinkhorn`)
+    - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain
+      defined as in [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape
+      is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source
+      domain and `C` is the number of classes
     - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C`
     - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n`
-    - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape is :math:`(n_k, C)`
+    - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in
+      [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape is :math:`(n_k, C)`
 
-    The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain.
+    The problem consists in solving a Wasserstein barycenter problem to estimate
+    the proportions :math:`\mathbf{h}` in the target domain.
 
     The algorithm used for solving the problem is the Iterative Bregman projections algorithm
-    with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform target distribution.
+    with two sets of marginal constraints related to the unknown vector
+    :math:`\mathbf{h}` and uniform target distribution.
 
     Parameters
     ----------
@@ -1826,10 +2577,12 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
         Max number of iterations
     stopThr : float, optional
         Stop threshold on relative change in the barycenter (>0)
-    log : bool, optional
-        record log if True
     verbose : bool, optional (default=False)
         Controls the verbosity of the optimization algorithm
+    log : bool, optional
+        record log if True
+    warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
 
     Returns
     -------
@@ -1844,9 +2597,8 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
     ----------
 
     .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
-       "Optimal transport for multi-source domain adaptation under target shift",
-       International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
-
+        "Optimal transport for multi-source domain adaptation under target shift",
+        International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
     '''
 
     Xs = list_to_array(*Xs)
@@ -1901,11 +2653,10 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
 
     # uniform target distribution
     a = nx.from_numpy(unif(Xt.shape[0]), type_as=Xs[0])
-    cpt = 0
     err = 1
     old_bary = nx.ones((nbclasses,), type_as=Xs[0])
 
-    while (err > stopThr and cpt < numItermax):
+    for ii in range(numItermax):
 
         bary = nx.zeros((nbclasses,), type_as=Xs[0])
 
@@ -1923,21 +2674,27 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
             K[d] = projR(K[d], new)
 
         err = nx.norm(bary - old_bary)
-        cpt = cpt + 1
+
         old_bary = bary
 
         if log:
             log['err'].append(err)
 
+        if err < stopThr:
+            break
+
         if verbose:
-            if cpt % 200 == 0:
+            if ii % 200 == 0:
                 print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
-                print('{:5d}|{:8e}|'.format(cpt, err))
-
+                print('{:5d}|{:8e}|'.format(ii, err))
+    else:
+        if warn:
+            warnings.warn("Algorithm did not converge. You might want to "
+                          "increase the number of iterations `numItermax` "
+                          "or the regularization parameter `reg`.")
     bary = bary / nx.sum(bary)
 
     if log:
-        log['niter'] = cpt
+        log['niter'] = ii
         log['M'] = M
         log['D1'] = D1
         log['D2'] = D2
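
A compact sketch of the proportion estimation (synthetic data, illustrative only; not part of the diff):

import numpy as np
import ot

rng = np.random.RandomState(0)
# two labeled source domains and one unlabeled target domain
Xs = [rng.randn(20, 2), rng.randn(30, 2) + .5]
Ys = [np.repeat([0, 1], 10), np.repeat([0, 1], 15)]
Xt = rng.randn(25, 2)

h = ot.bregman.jcpot_barycenter(Xs, Ys, Xt, reg=1.)
# h has one entry per class and sums to 1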
@@ -1949,7 +2706,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
 
 def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
                        numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
-                       log=False, **kwargs):
+                       log=False, warn=True, **kwargs):
     r'''
     Solve the entropic regularization optimal transport problem and return the
     OT matrix from empirical data
 
@@ -1967,7 +2724,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
     where :
 
     - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix
-    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`\Omega` is the entropic regularization term
+      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
     - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
 
@@ -1988,7 +2746,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
     stopThr : float, optional
         Stop threshold on error (>0)
     isLazy: boolean, optional
-        If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory). If False, calculate full cost matrix and return outputs of sinkhorn function.
+        If True, then only calculate the cost matrix by block and return
+        the dual potentials only (to save memory). If False, calculate full
+        cost matrix and return outputs of sinkhorn function.
     batchSize: int or tuple of 2 int, optional
         Size of the batches used to compute the sinkhorn update without memory overhead.
         When a tuple is provided it sets the size of the left/right batches.
@@ -1996,6 +2756,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
         Print information along iterations
     log : bool, optional
         record log if True
+    warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
 
 
     Returns
@@ -2021,11 +2783,14 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
     References
     ----------
 
-    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+        Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
 
-    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+        Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
 
-    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+        Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
     '''
 
     X_s, X_t = list_to_array(X_s, X_t)
@@ -2100,7 +2865,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
 
                 if err <= stopThr:
                     break
-
+        else:
+            if warn:
+                warnings.warn("Sinkhorn did not converge. You might want to "
+                              "increase the number of iterations `numItermax` "
+                              "or the regularization parameter `reg`.")
         if log:
             dict_log["u"] = f
            dict_log["v"] = g
@@ -2111,15 +2880,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
     else:
         M = dist(X_s, X_t, metric=metric)
         if log:
-            pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
+            pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
+                               verbose=verbose, log=True, **kwargs)
             return pi, log
         else:
-            pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
+            pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
+                          verbose=verbose, log=False, **kwargs)
             return pi
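
The lazy path above can be exercised as follows (sketch with synthetic data; only the dual potentials are returned and the full plan is never materialized):

import numpy as np
import ot

rng = np.random.RandomState(0)
X_s, X_t = rng.randn(500, 2), rng.randn(500, 2)

f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, reg=1., isLazy=True,
                                     batchSize=128)
print(f.shape, g.shape)  # (500,) (500,)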
 
 
-def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
-                        isLazy=False, batchSize=100, verbose=False, log=False, **kwargs):
+def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
+                        numIterMax=10000, stopThr=1e-9, isLazy=False,
+                        batchSize=100, verbose=False, log=False, warn=True, **kwargs):
     r'''
     Solve the entropic regularization optimal transport problem from empirical
     data and return the OT loss
 
@@ -2138,7 +2910,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
     where :
 
     - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix
-    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`\Omega` is the entropic regularization term
+      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
     - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
 
@@ -2159,7 +2932,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
     stopThr : float, optional
         Stop threshold on error (>0)
     isLazy: boolean, optional
-        If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory). If False, calculate full cost matrix and return outputs of sinkhorn function.
+        If True, then only calculate the cost matrix by block and return
+        the dual potentials only (to save memory). If False, calculate
+        full cost matrix and return outputs of sinkhorn function.
     batchSize: int or tuple of 2 int, optional
         Size of the batches used to compute the sinkhorn update without memory overhead.
         When a tuple is provided it sets the size of the left/right batches.
@@ -2167,6 +2942,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
         Print information along iterations
     log : bool, optional
         record log if True
+    warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
 
 
     Returns
@@ -2192,11 +2969,17 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
     References
     ----------
 
-    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+        of Optimal Transport, Advances in Neural Information
+        Processing Systems (NIPS) 26, 2013
 
-    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling
+        Algorithms for Entropy Regularized Transport Problems.
+        arXiv preprint arXiv:1610.06519.
 
-    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+        Scaling algorithms for unbalanced transport problems.
+        arXiv preprint arXiv:1607.05816.
     '''
 
     X_s, X_t = list_to_array(X_s, X_t)
@@ -2211,11 +2994,19 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
 
     if isLazy:
         if log:
-            f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr,
-                                                isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log)
+            f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
+                                                numIterMax=numIterMax,
+                                                stopThr=stopThr,
+                                                isLazy=isLazy,
+                                                batchSize=batchSize,
+                                                verbose=verbose, log=log,
+                                                warn=warn)
         else:
-            f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr,
-                                      isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log)
+            f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
+                                      numIterMax=numIterMax, stopThr=stopThr,
+                                      isLazy=isLazy, batchSize=batchSize,
+                                      verbose=verbose, log=log,
+                                      warn=warn)
 
         bs = batchSize if isinstance(batchSize, int) else batchSize[0]
         range_s = range(0, ns, bs)
@@ -2241,17 +3032,21 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
 
         M = nx.from_numpy(M, type_as=a)
 
         if log:
-            sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
-                                           **kwargs)
+            sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
+                                           stopThr=stopThr, verbose=verbose, log=log,
+                                           warn=warn, **kwargs)
             return sinkhorn_loss, log
         else:
-            sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
-                                      **kwargs)
+            sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
+                                      stopThr=stopThr, verbose=verbose, log=log,
+                                      warn=warn, **kwargs)
             return sinkhorn_loss
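
And a sketch of the loss variant together with the new `warn` switch (synthetic data; not part of the diff):

import numpy as np
import ot

rng = np.random.RandomState(0)
X_s, X_t = rng.randn(100, 2), rng.randn(100, 2) + .5

loss = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg=1.)

# warn=False silences the new convergence warning when a rough solution
# obtained under a small iteration budget is acceptable
rough = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg=1.,
                                       numIterMax=5, warn=False)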
 
 
-def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
-                                  verbose=False, log=False, **kwargs):
+def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
+                                  numIterMax=10000, stopThr=1e-9,
+                                  verbose=False, log=False, warn=True,
+                                  **kwargs):
     r'''
     Compute the sinkhorn divergence loss from empirical data
 
@@ -2288,8 +3083,11 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
         \gamma_b &\geq 0
 
     where :
-    - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) is the (`n_samples_a`, `n_samples_b`) metric cost matrix (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`))
-    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`)
+      is the (`n_samples_a`, `n_samples_b`) metric cost matrix
+      (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`))
+    - :math:`\Omega` is the entropic regularization term
+      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
     - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
 
@@ -2313,6 +3111,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
         Print information along iterations
     log : bool, optional
         record log if True
+    warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
 
 
     Returns
     -------
@@ -2334,17 +3134,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
     References
     ----------
 
-    .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
+    .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative
+        Models with Sinkhorn Divergences, Proceedings of the Twenty-First
+        International Conference on Artificial Intelligence and Statistics,
+        (AISTATS) 21, 2018
     '''
     if log:
-        sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
-                                                       stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
+                                                       numIterMax=numIterMax,
+                                                       stopThr=1e-9, verbose=verbose,
+                                                       log=log, warn=warn, **kwargs)
 
-        sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax,
-                                                     stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
+                                                     numIterMax=numIterMax,
+                                                     stopThr=1e-9, verbose=verbose,
+                                                     log=log, warn=warn, **kwargs)
 
-        sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax,
-                                                     stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
+                                                     numIterMax=numIterMax,
+                                                     stopThr=1e-9, verbose=verbose,
+                                                     log=log, warn=warn, **kwargs)
 
         sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
 
@@ -2359,25 +3168,33 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
         return max(0, sinkhorn_div), log
     else:
-        sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
-                                               verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
+                                               numIterMax=numIterMax, stopThr=1e-9,
+                                               verbose=verbose, log=log,
+                                               warn=warn, **kwargs)
 
-        sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
-                                              verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
+                                              numIterMax=numIterMax, stopThr=1e-9,
+                                              verbose=verbose, log=log,
+                                              warn=warn, **kwargs)
 
-        sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
-                                              verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
+                                              numIterMax=numIterMax, stopThr=1e-9,
+                                              verbose=verbose, log=log,
+                                              warn=warn, **kwargs)
 
         sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
         return max(0, sinkhorn_div)
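
A quick sanity check of the divergence (sketch, synthetic data): it is clamped to be non-negative and essentially vanishes when both samples coincide:

import numpy as np
import ot

rng = np.random.RandomState(0)
X_s, X_t = rng.randn(60, 2), rng.randn(60, 2) + 1.

d_st = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg=1.)
d_ss = ot.bregman.empirical_sinkhorn_divergence(X_s, X_s, reg=1.)
print(d_st, d_ss)  # d_st > 0, d_ss close to 0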
 
 
-def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True,
-                maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False):
+def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
+                restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09,
+                verbose=False, log=False):
     r"""
     Screening Sinkhorn Algorithm for Regularized Optimal Transport
 
-    The function solves an approximate dual of Sinkhorn divergence :ref:`[2] <references-screenkhorn>` which is written as the following optimization problem:
+    The function solves an approximate dual of Sinkhorn divergence :ref:`[2]
+    <references-screenkhorn>` which is written as the following optimization problem:
     .. math::
@@ -2395,56 +3212,49 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
 
         e^{v_j} &\geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\}
 
-    The parameters `kappa` and `epsilon` are determined w.r.t the couple number budget of points (`ns_budget`, `nt_budget`), see Equation (5) in :ref:`[26] <references-screenkhorn>`
+    The parameters `kappa` and `epsilon` are determined w.r.t. the couple number
+    budget of points (`ns_budget`, `nt_budget`), see Equation (5)
+    in :ref:`[26] <references-screenkhorn>`
 
     Parameters
     ----------
-    a : array-like, shape=(ns,)
-        samples weights in the source domain
-
-    b : array-like, shape=(nt,)
-        samples weights in the target domain
-
-    M : array-like, shape=(ns, nt)
-        Cost matrix
-
-    reg : `float`
-        Level of the entropy regularisation
-
-    ns_budget : `int`, default=None
-        Number budget of points to be kept in the source domain. If it is None then 50% of the source sample points will be kept
-
-    nt_budget : `int`, default=None
-        Number budget of points to be kept in the target domain. If it is None then 50% of the target sample points will be kept
-
-    uniform : `bool`, default=False
-        If `True`, the source and target distribution are supposed to be uniform, i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt`
-
+    a: array-like, shape=(ns,)
+        samples weights in the source domain
+    b: array-like, shape=(nt,)
+        samples weights in the target domain
+    M: array-like, shape=(ns, nt)
+        Cost matrix
+    reg: `float`
+        Level of the entropy regularisation
+    ns_budget: `int`, default=None
+        Number budget of points to be kept in the source domain.
+        If it is None then 50% of the source sample points will be kept
+    nt_budget: `int`, default=None
+        Number budget of points to be kept in the target domain.
+        If it is None then 50% of the target sample points will be kept
+    uniform: `bool`, default=False
+        If `True`, the source and target distribution are supposed to be uniform,
+        i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt`
     restricted : `bool`, default=True
         If `True`, a warm-start initialization for the  L-BFGS-B solver
         using a restricted Sinkhorn algorithm with at most 5 iterations
-
-    maxiter : `int`, default=10000
+    maxiter: `int`, default=10000
         Maximum number of iterations in LBFGS solver
-
-    maxfun : `int`, default=10000
+    maxfun: `int`, default=10000
         Maximum number of function evaluations in LBFGS solver
-
-    pgtol : `float`, default=1e-09
+    pgtol: `float`, default=1e-09
         Final objective function accuracy in LBFGS solver
-
-    verbose : `bool`, default=False
-        If `True`, display informations about the cardinals of the active sets and the parameters kappa
-        and epsilon
-
+    verbose: `bool`, default=False
+        If `True`, display information about the cardinals of the active sets
+        and the parameters kappa and epsilon
 
     Dependency
     ----------
-    To gain more efficiency, screenkhorn needs to call the "Bottleneck" package (https://pypi.org/project/Bottleneck/)
-    in the screening pre-processing step. If Bottleneck isn't installed, the following error message appears:
+    To gain more efficiency, screenkhorn needs to call the "Bottleneck"
+    package (https://pypi.org/project/Bottleneck/)
+    in the screening pre-processing step. If Bottleneck isn't installed,
+    the following error message appears:
 
     "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/"
 
@@ -2461,9 +3271,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
 
     References
     -----------
-    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport,
+        Advances in Neural Information Processing Systems (NIPS) 26, 2013
 
-    .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019
+    .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019).
+        Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019
     """
     # check if bottleneck module exists
@@ -2471,14 +3283,16 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
         import bottleneck
     except ImportError:
         warnings.warn(
-            "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
+            "Bottleneck module is not installed. Install it from"
+            " https://pypi.org/project/Bottleneck/ for better performance.")
         bottleneck = np
 
     a, b, M = list_to_array(a, b, M)
 
     nx = get_backend(M, a, b)
     if nx.__name__ == "jax":
-        raise TypeError("JAX arrays have been received but screenkhorn is not compatible with JAX.")
+        raise TypeError("JAX arrays have been received but screenkhorn is not "
+                        "compatible with JAX.")
 
     ns, nt = M.shape
 
@@ -2582,7 +3396,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
     if verbose:
         print("epsilon = %s\n" % epsilon)
         print("kappa = %s\n" % kappa)
-        print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' % (sum(Isel), sum(Jsel)))
+        print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n'
+              % (sum(Isel), sum(Jsel)))
 
     # Ic, Jc: complementary of the active sets I and J
     Ic = ~Isel
@@ -2638,13 +3453,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
         cst_u = kappa * epsilon * nx.sum(K_IJc, axis=1)
         cst_v = epsilon * nx.sum(K_IcJ, axis=0) / kappa
 
-    cpt = 1
-    while cpt < 5:  # 5 iterations
+    for _ in range(5):  # 5 iterations
         K_IJ_v = nx.dot(K_IJ.T, u0) + cst_v
         v0 = b_J / (kappa * K_IJ_v)
         KIJ_u = nx.dot(K_IJ, v0) + cst_u
         u0 = (kappa * a_I) / KIJ_u
-        cpt += 1
 
     u0 = projection(u0, epsilon / kappa)
     v0 = projection(v0, epsilon * kappa)
@@ -2655,15 +3468,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
 
     def restricted_sinkhorn(usc, vsc, max_iter=5):
         """
-        Restricted Sinkhorn Algorithm as a warm-start initialized point for L-BFGS-B (see Algorithm 1 in supplementary of [26])
+        Restricted Sinkhorn Algorithm as a warm-start initialized point
+        for L-BFGS-B (see Algorithm 1 in supplementary of [26])
         """
-        cpt = 1
-        while cpt < max_iter:
+        for _ in range(max_iter):
             K_IJ_v = nx.dot(K_IJ.T, usc) + cst_v
             vsc = b_J / (kappa * K_IJ_v)
             KIJ_u = nx.dot(K_IJ, vsc) + cst_u
             usc = (kappa * a_I) / KIJ_u
-            cpt += 1
 
         usc = projection(usc, epsilon / kappa)
         vsc = projection(vsc, epsilon * kappa)
-- 
cgit v1.2.3