path: root/ot/unbalanced.py
author    Hicham Janati <hicham.janati@inria.fr>    2019-07-19 17:04:14 +0200
committer Hicham Janati <hicham.janati@inria.fr>    2019-08-28 15:36:37 +0200
commit    cfdbbd21642c6082164b84db78c2ead07499a113 (patch)
tree      11ac0f20c9d5881802ac348be86dff65e5db67d7 /ot/unbalanced.py
parent    abfe183a49caaf74a07e595ac40920dae05a3c22 (diff)
remove square in convergence check
add unbalanced with stabilization
add unbalanced tests with stabilization
fix doctest examples
add xvfb in travis
remove explicit call xvfb in travis
change alpha to reg_m
minor flake8
remove redundant sink definitions + better doc and naming
add stabilized unbalanced barycenter + add not converged warnings
add test for stable barycenter
add generic barycenter func + make method funcs private
fix typo + add method test for barycenters
fix doc examples + add xml to gitignore
fix whitespace in example
change logsumexp import - scipy deprecation warning
fix doctest
improve naming + add stable barycenter in bregman
add test for stable bar + test the method arg in bregman
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r--    ot/unbalanced.py    803
1 files changed, 653 insertions, 150 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 0f0692e..3f71d28 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -9,51 +9,56 @@ Regularized Unbalanced OT
from __future__ import division
import warnings
import numpy as np
+from scipy.special import logsumexp
+
# from .utils import unif, dist
-def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
- Solve the unbalanced entropic regularization optimal transport problem and return the loss
+ Solve the unbalanced entropic regularization optimal transport problem
+ and return the OT plan
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
where :
- - M is the (ns, nt) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization
+ term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
Parameters
----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,n_hists)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns, nt)
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m: float
Marginal relaxation term > 0
method : str
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ 'sinkhorn_reg_scaling', see those function for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -62,10 +67,16 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
Returns
-------
- W : (nt) ndarray or float
- Optimal transportation matrix for the given parameters
- log : dict
- log dictionary return only if log==True in parameters
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
Examples
--------
@@ -82,83 +93,96 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems
+ (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
- .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
See Also
--------
ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10]
- ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_stabilized_unbalanced:
+ Unbalanced Stabilized sinkhorn [9][10]
+ ot.unbalanced.sinkhorn_reg_scaling_unbalanced:
+ Unbalanced Sinkhorn with epsilon scaling [9][10]
"""
if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
-
- elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
+ return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
else:
- raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
-
- return sink()
+ raise ValueError("Unknown method '%s'." % method)
-def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
- numItermax=1000, stopThr=1e-9, verbose=False,
+def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
+ numItermax=1000, stopThr=1e-6, verbose=False,
log=False, **kwargs):
r"""
- Solve the entropic regularization unbalanced optimal transport problem and return the loss
+ Solve the entropic regularization unbalanced optimal transport problem and
+ return the loss
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
where :
- - M is the (ns, nt) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
Parameters
----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt, n_hists)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m: float
Marginal relaxation term > 0
method : str
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ 'sinkhorn_reg_scaling', see those function for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
@@ -171,10 +195,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
Returns
-------
- W : (nt) ndarray or float
- Optimal transportation matrix for the given parameters
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
log : dict
- log dictionary return only if log==True in parameters
+ log dictionary returned only if `log` is `True`
Examples
--------
@@ -191,64 +215,70 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems
+ (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
- .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
See Also
--------
ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling [9][10]
"""
-
- if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
-
- elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
- warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
- else:
- raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
-
b = np.asarray(b, dtype=np.float64)
if len(b.shape) < 2:
b = b[:, None]
-
- return sink()
+ if method.lower() == 'sinkhorn':
+ return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
+ return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError('Unknown method %s.' % method)
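A sketch of the multi-histogram path documented above: when `b` has shape (dim_b, n_hists), `sinkhorn_unbalanced2` returns one OT loss per column of `b`. Toy values, illustrative only:

import numpy as np
import ot

a = np.array([.5, .5])
B = np.array([[.5, .4],
              [.5, .6]])                # two target histograms, shape (dim_b, n_hists)
M = np.array([[0., 1.], [1., 0.]])

losses = ot.unbalanced.sinkhorn_unbalanced2(a, B, M, reg=1., reg_m=1.)
# losses has shape (2,): one unbalanced OT loss per histogram b_i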
-def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
Solve the entropic regularization unbalanced optimal transport problem and return the loss
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
where :
- - M is the (ns, nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights
+ - a and b are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
@@ -256,16 +286,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
Parameters
----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt, n_hists)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m: float
Marginal relaxation term > 0
numItermax : int, optional
Max number of iterations
@@ -279,11 +309,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
Returns
-------
- gamma : (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
- log : dict
- log dictionary return only if log==True in parameters
-
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
Examples
--------
@@ -291,16 +326,20 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> M=[[0., 1.],[1., 0.]]
- >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.)
+ >>> ot.unbalanced._sinkhorn_knopp_unbalanced(a, b, M, 1., 1.)
array([[0.51122823, 0.18807035],
[0.18807035, 0.51122823]])
References
----------
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
- .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
See Also
--------
@@ -313,12 +352,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
- n_a, n_b = M.shape
+ dim_a, dim_b = M.shape
if len(a) == 0:
- a = np.ones(n_a, dtype=np.float64) / n_a
+ a = np.ones(dim_a, dtype=np.float64) / dim_a
if len(b) == 0:
- b = np.ones(n_b, dtype=np.float64) / n_b
+ b = np.ones(dim_b, dtype=np.float64) / dim_b
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -331,21 +370,19 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((n_a, 1)) / n_a
- v = np.ones((n_b, n_hists)) / n_b
- a = a.reshape(n_a, 1)
+ u = np.ones((dim_a, 1)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
+ a = a.reshape(dim_a, 1)
else:
- u = np.ones(n_a) / n_a
- v = np.ones(n_b) / n_b
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
- # print(reg)
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
K = np.empty(M.shape, dtype=M.dtype)
np.divide(M, -reg, out=K)
np.exp(K, out=K)
- # print(np.min(K))
- fi = alpha / (alpha + reg)
+ fi = reg_m / (reg_m + reg)
cpt = 0
err = 1.
@@ -371,8 +408,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
if verbose:
@@ -383,8 +421,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
cpt += 1
if log:
- log['u'] = u
- log['v'] = v
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
if n_hists: # return only loss
res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
@@ -401,9 +439,224 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
return u[:, None] * K * v[None, :]
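The hunks above reduce to two ingredients: the scaling updates with exponent fi = reg_m / (reg_m + reg), and the new unsquared relative-change stopping rule that replaces the former squared-L2 criterion (the commit subject). Below is a standalone NumPy sketch with a hypothetical name, checking the error every iteration rather than every tenth and assuming float arrays a, b, M; it is not a drop-in replacement for the function in this diff:

import numpy as np

def sinkhorn_knopp_unbalanced_sketch(a, b, M, reg, reg_m,
                                     numItermax=1000, stopThr=1e-6):
    K = np.exp(-M / reg)                 # Gibbs kernel
    fi = reg_m / (reg_m + reg)           # exponent from the KL marginal relaxation
    u = np.ones(a.shape[0]) / a.shape[0]
    v = np.ones(b.shape[0]) / b.shape[0]
    for _ in range(numItermax):
        uprev, vprev = u, v
        u = (a / (K.dot(v) + 1e-16)) ** fi
        v = (b / (K.T.dot(u) + 1e-16)) ** fi
        # relative max-change criterion replacing the old squared L2 ratio
        err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
        err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
        if 0.5 * (err_u + err_v) < stopThr:
            break
    return u[:, None] * K * v[None, :]   # transport plan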
-def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
- r"""Compute the entropic regularized unbalanced wasserstein barycenter of distributions A
+def _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False,
+ **kwargs):
+ r"""
+ Solve the entropic regularization unbalanced optimal transport
+ problem and return the loss
+
+ The function solves the following optimization problem using log-domain
+ stabilization as proposed in [10]:
+
+ .. math::
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+
+ s.t.
+ \gamma\geq 0
+ where :
+
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization
+ term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
+ - KL is the Kullback-Leibler divergence
+
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Entropy regularization term > 0
+ reg_m: float
+ Marginal relaxation term > 0
+ tau : float
+ threshold for max value in u or v for log scaling
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.],[1., 0.]]
+ >>> ot.unbalanced._sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.)
+ array([[0.51122823, 0.18807035],
+ [0.18807035, 0.51122823]])
+
+ References
+ ----------
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ dim_a, dim_b = M.shape
+
+ if len(a) == 0:
+ a = np.ones(dim_a, dtype=np.float64) / dim_a
+ if len(b) == 0:
+ b = np.ones(dim_b, dtype=np.float64) / dim_b
+
+ if len(b.shape) > 1:
+ n_hists = b.shape[1]
+ else:
+ n_hists = 0
+
+ if log:
+ log = {'err': []}
+
+ # we assume that no distances are null except those of the diagonal of
+ # distances
+ if n_hists:
+ u = np.ones((dim_a, n_hists)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
+ a = a.reshape(dim_a, 1)
+ else:
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
+
+ # print(reg)
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ fi = reg_m / (reg_m + reg)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim_a)
+ beta = np.zeros(dim_b)
+ while (err > stopThr and cpt < numItermax):
+ uprev = u
+ vprev = v
+
+ Kv = K.dot(v)
+ f_alpha = np.exp(- alpha / (reg + reg_m))
+ f_beta = np.exp(- beta / (reg + reg_m))
+
+ if n_hists:
+ f_alpha = f_alpha[:, None]
+ f_beta = f_beta[:, None]
+ u = ((a / (Kv + 1e-16)) ** fi) * f_alpha
+ Ktu = K.T.dot(u)
+ v = ((b / (Ktu + 1e-16)) ** fi) * f_beta
+ absorbing = False
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
+ if n_hists:
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ else:
+ alpha = alpha + reg * np.log(np.max(u))
+ beta = beta + reg * np.log(np.max(v))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ u = uprev
+ v = vprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(),
+ 1.)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+ cpt = cpt + 1
+
+ if err > stopThr:
+ warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
+ "Try a larger entropy `reg` or a lower mass `reg_m`." +
+ "Or a larger absorption threshold `tau`.")
+ if n_hists:
+ logu = alpha[:, None] / reg + np.log(u)
+ logv = beta[:, None] / reg + np.log(v)
+ else:
+ logu = alpha / reg + np.log(u)
+ logv = beta / reg + np.log(v)
+ if log:
+ log['logu'] = logu
+ log['logv'] = logv
+ if n_hists: # return only loss
+ res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] +
+ logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1))
+ res = np.exp(res)
+ if log:
+ return res, log
+ else:
+ return res
+
+ else: # return OT matrix
+ ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg)
+ if log:
+ return ot_matrix, log
+ else:
+ return ot_matrix
+
+
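The stabilization introduced above hinges on one absorption step: whenever a scaling vector exceeds `tau`, its log is absorbed into the dual potentials alpha and beta and the kernel is rebuilt around them, so u and v stay bounded. A minimal sketch of that step for the single-histogram branch, with a hypothetical helper name mirroring the lines above rather than replacing them:

import numpy as np

def absorb(u, v, alpha, beta, M, reg):
    # fold the current scalings into the dual potentials ...
    alpha = alpha + reg * np.log(np.max(u))
    beta = beta + reg * np.log(np.max(v))
    # ... then rebuild the stabilized kernel around them and reset v
    K = np.exp((alpha[:, None] + beta[None, :] - M) / reg)
    v = np.ones_like(v)
    return K, v, alpha, beta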
+def _barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization.
The function solves the following optimization problem:
@@ -412,28 +665,184 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
where :
- - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
- - alpha is the marginal relaxation hyperparameter
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of
+ matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+ - reg_m is the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
Parameters
----------
- A : np.ndarray (d,n)
- n training distributions a_i of size d
- M : np.ndarray (d,d)
- loss matrix for OT
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m : float
Marginal relaxation term > 0
- weights : np.ndarray (n,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ tau : float
+ Stabilization threshold for log domain absorption.
+ weights : np.ndarray (n_hists,) optional
+ Weight of each distribution (barycentric coordinates)
+ If None, uniform weights are used.
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Unbalanced Wasserstein barycenter
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré,
+ G. (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+
+ """
+ dim, n_hists = A.shape
+ if weights is None:
+ weights = np.ones(n_hists) / n_hists
+ else:
+ assert(len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ fi = reg_m / (reg_m + reg)
+
+ u = np.ones((dim, n_hists)) / dim
+ v = np.ones((dim, n_hists)) / dim
+
+ # print(reg)
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim)
+ beta = np.zeros(dim)
+ q = np.ones(dim) / dim
+ while (err > stopThr and cpt < numItermax):
+ qprev = q
+ Kv = K.dot(v)
+ f_alpha = np.exp(- alpha / (reg + reg_m))
+ f_beta = np.exp(- beta / (reg + reg_m))
+ f_alpha = f_alpha[:, None]
+ f_beta = f_beta[:, None]
+ u = ((A / (Kv + 1e-16)) ** fi) * f_alpha
+ Ktu = K.T.dot(u)
+ q = (Ktu ** (1 - fi)) * f_beta
+ q = q.dot(weights) ** (1 / (1 - fi))
+ Q = q[:, None]
+ v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta
+ absorbing = False
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ q = qprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = abs(q - qprev).max() / max(abs(q).max(),
+ abs(qprev).max(), 1.)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 50 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+ if err > stopThr:
+ warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
+ "Try a larger entropy `reg` or a lower mass `reg_m`." +
+ "Or a larger absorption threshold `tau`.")
+ if log:
+ log['niter'] = cpt
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
+ return q, log
+ else:
+ return q
+
+
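The barycenter update inside the loop above combines the per-histogram marginals K^T u into the common barycenter q through a weighted power mean with exponent 1 / (1 - fi). A sketch of just that aggregation step, with a hypothetical helper name; Ktu has shape (dim, n_hists), f_beta has shape (dim, 1), and weights lie on the simplex:

import numpy as np

def barycenter_update_sketch(Ktu, f_beta, weights, fi):
    q = (Ktu ** (1 - fi)) * f_beta               # per-histogram contribution
    return q.dot(weights) ** (1 / (1 - fi))      # weighted power mean across histograms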
+def _barycenter_unbalanced(A, M, reg, reg_m, weights=None,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+ - reg_m is the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ Parameters
+ ----------
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
+ reg : float
+ Entropy regularization term > 0
+ reg_m: float
+ Marginal relaxation term > 0
+ weights : np.ndarray (n_hists,) optional
+ Weight of each distribution (barycentric coordinates)
+ If None, uniform weights are used.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -442,7 +851,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
Returns
-------
- a : (d,) ndarray
+ a : (dim,) ndarray
Unbalanced Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -451,12 +860,16 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+ (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
"""
- p, n_hists = A.shape
+ dim, n_hists = A.shape
if weights is None:
weights = np.ones(n_hists) / n_hists
else:
@@ -467,10 +880,10 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
K = np.exp(- M / reg)
- fi = alpha / (alpha + reg)
+ fi = reg_m / (reg_m + reg)
- v = np.ones((p, n_hists)) / p
- u = np.ones((p, 1)) / p
+ v = np.ones((dim, n_hists)) / dim
+ u = np.ones((dim, 1)) / dim
cpt = 0
err = 1.
@@ -499,8 +912,11 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \
- np.sum((v - vprev) ** 2) / np.sum((v) ** 2)
+ err_u = abs(u - uprev).max()
+ err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max()
+ err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
if verbose:
@@ -512,8 +928,95 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
cpt += 1
if log:
log['niter'] = cpt
- log['u'] = u
- log['v'] = v
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
return q, log
else:
return q
+
+
+def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False, **kwargs):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+ - reg_m is the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ Parameters
+ ----------
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
+ reg : float
+ Entropy regularization term > 0
+ reg_m: float
+ Marginal relaxation term > 0
+ weights : np.ndarray (n_hists,) optional
+ Weight of each distribution (barycentric coordinates)
+ If None, uniform weights are used.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Unbalanced Wasserstein barycenter
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+ (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+ """
+
+ if method.lower() == 'sinkhorn':
+ return _barycenter_unbalanced(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return _barycenter_unbalanced_stabilized(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
+ return _barycenter_unbalanced(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)