summaryrefslogtreecommitdiff
path: root/ot/unbalanced.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r--ot/unbalanced.py808
1 files changed, 656 insertions, 152 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 50ec03c..d516dfc 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -9,51 +9,56 @@ Regularized Unbalanced OT
from __future__ import division
import warnings
import numpy as np
+from scipy.special import logsumexp
+
# from .utils import unif, dist
-def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
- Solve the unbalanced entropic regularization optimal transport problem and return the loss
+ Solve the unbalanced entropic regularization optimal transport problem
+ and return the OT plan
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
where :
- - M is the (ns, nt) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization
+ term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
Parameters
----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,n_hists)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns, nt)
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m: float
Marginal relaxation term > 0
method : str
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ 'sinkhorn_reg_scaling', see those function for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshol on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -62,10 +67,16 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
Returns
-------
- W : (nt) ndarray or float
- Optimal transportation matrix for the given parameters
- log : dict
- log dictionary return only if log==True in parameters
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
Examples
--------
@@ -82,83 +93,96 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems
+ (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
- .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
See Also
--------
ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10]
- ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_stabilized_unbalanced:
+ Unbalanced Stabilized sinkhorn [9][10]
+ ot.unbalanced.sinkhorn_reg_scaling_unbalanced:
+ Unbalanced Sinkhorn with epslilon scaling [9][10]
"""
if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
-
- elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
else:
- raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
+ raise ValueError("Unknown method '%s'." % method)
- return sink()
-
-def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
- numItermax=1000, stopThr=1e-9, verbose=False,
+def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
+ numItermax=1000, stopThr=1e-6, verbose=False,
log=False, **kwargs):
r"""
- Solve the entropic regularization unbalanced optimal transport problem and return the loss
+ Solve the entropic regularization unbalanced optimal transport problem and
+ return the loss
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
where :
- - M is the (ns, nt) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
Parameters
----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt, n_hists)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m: float
Marginal relaxation term > 0
method : str
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ 'sinkhorn_reg_scaling', see those function for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
@@ -171,10 +195,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
Returns
-------
- W : (nt) ndarray or float
- Optimal transportation matrix for the given parameters
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
log : dict
- log dictionary return only if log==True in parameters
+ log dictionary returned only if `log` is `True`
Examples
--------
@@ -191,64 +215,70 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems
+ (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
- .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
See Also
--------
ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
"""
-
- if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
-
- elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
- warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
- else:
- raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
-
b = np.asarray(b, dtype=np.float64)
if len(b.shape) < 2:
b = b[:, None]
-
- return sink()
+ if method.lower() == 'sinkhorn':
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError('Unknown method %s.' % method)
-def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
Solve the entropic regularization unbalanced optimal transport problem and return the loss
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \reg_m KL(\gamma 1, a) + \reg_m KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
where :
- - M is the (ns, nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights
+ - a and b are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
@@ -256,21 +286,21 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
Parameters
----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt, n_hists)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m: float
Marginal relaxation term > 0
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshol on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -279,11 +309,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
Returns
-------
- gamma : (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
- log : dict
- log dictionary return only if log==True in parameters
-
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
Examples
--------
@@ -298,9 +333,13 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
References
----------
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
- .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
See Also
--------
@@ -313,12 +352,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
- n_a, n_b = M.shape
+ dim_a, dim_b = M.shape
if len(a) == 0:
- a = np.ones(n_a, dtype=np.float64) / n_a
+ a = np.ones(dim_a, dtype=np.float64) / dim_a
if len(b) == 0:
- b = np.ones(n_b, dtype=np.float64) / n_b
+ b = np.ones(dim_b, dtype=np.float64) / dim_b
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -331,21 +370,19 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((n_a, 1)) / n_a
- v = np.ones((n_b, n_hists)) / n_b
- a = a.reshape(n_a, 1)
+ u = np.ones((dim_a, 1)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
+ a = a.reshape(dim_a, 1)
else:
- u = np.ones(n_a) / n_a
- v = np.ones(n_b) / n_b
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
- # print(reg)
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
K = np.empty(M.shape, dtype=M.dtype)
np.divide(M, -reg, out=K)
np.exp(K, out=K)
- # print(np.min(K))
- fi = alpha / (alpha + reg)
+ fi = reg_m / (reg_m + reg)
cpt = 0
err = 1.
@@ -364,15 +401,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- warnings.warn('Numerical errors at iteration', cpt)
+ warnings.warn('Numerical errors at iteration %s' % cpt)
u = uprev
v = vprev
break
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
if verbose:
@@ -380,10 +418,11 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))
- cpt = cpt + 1
+ cpt += 1
+
if log:
- log['u'] = u
- log['v'] = v
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
if n_hists: # return only loss
res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
@@ -400,9 +439,224 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
return u[:, None] * K * v[None, :]
-def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
- r"""Compute the entropic regularized unbalanced wasserstein barycenter of distributions A
+def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False,
+ **kwargs):
+ r"""
+ Solve the entropic regularization unbalanced optimal transport
+ problem and return the loss
+
+ The function solves the following optimization problem using log-domain
+ stabilization as proposed in [10]:
+
+ .. math::
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+
+ s.t.
+ \gamma\geq 0
+ where :
+
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization
+ term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
+ - KL is the Kullback-Leibler divergence
+
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Entropy regularization term > 0
+ reg_m: float
+ Marginal relaxation term > 0
+ tau : float
+ thershold for max value in u or v for log scaling
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.],[1., 0.]]
+ >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.)
+ array([[0.51122823, 0.18807035],
+ [0.18807035, 0.51122823]])
+
+ References
+ ----------
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ dim_a, dim_b = M.shape
+
+ if len(a) == 0:
+ a = np.ones(dim_a, dtype=np.float64) / dim_a
+ if len(b) == 0:
+ b = np.ones(dim_b, dtype=np.float64) / dim_b
+
+ if len(b.shape) > 1:
+ n_hists = b.shape[1]
+ else:
+ n_hists = 0
+
+ if log:
+ log = {'err': []}
+
+ # we assume that no distances are null except those of the diagonal of
+ # distances
+ if n_hists:
+ u = np.ones((dim_a, n_hists)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
+ a = a.reshape(dim_a, 1)
+ else:
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
+
+ # print(reg)
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ fi = reg_m / (reg_m + reg)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim_a)
+ beta = np.zeros(dim_b)
+ while (err > stopThr and cpt < numItermax):
+ uprev = u
+ vprev = v
+
+ Kv = K.dot(v)
+ f_alpha = np.exp(- alpha / (reg + reg_m))
+ f_beta = np.exp(- beta / (reg + reg_m))
+
+ if n_hists:
+ f_alpha = f_alpha[:, None]
+ f_beta = f_beta[:, None]
+ u = ((a / (Kv + 1e-16)) ** fi) * f_alpha
+ Ktu = K.T.dot(u)
+ v = ((b / (Ktu + 1e-16)) ** fi) * f_beta
+ absorbing = False
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
+ if n_hists:
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ else:
+ alpha = alpha + reg * np.log(np.max(u))
+ beta = beta + reg * np.log(np.max(v))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ u = uprev
+ v = vprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(),
+ 1.)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+ cpt = cpt + 1
+
+ if err > stopThr:
+ warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
+ "Try a larger entropy `reg` or a lower mass `reg_m`." +
+ "Or a larger absorption threshold `tau`.")
+ if n_hists:
+ logu = alpha[:, None] / reg + np.log(u)
+ logv = beta[:, None] / reg + np.log(v)
+ else:
+ logu = alpha / reg + np.log(u)
+ logv = beta / reg + np.log(v)
+ if log:
+ log['logu'] = logu
+ log['logv'] = logv
+ if n_hists: # return only loss
+ res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] +
+ logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1))
+ res = np.exp(res)
+ if log:
+ return res, log
+ else:
+ return res
+
+ else: # return OT matrix
+ ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg)
+ if log:
+ return ot_matrix, log
+ else:
+ return ot_matrix
+
+
+def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization.
The function solves the following optimization problem:
@@ -411,28 +665,35 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
where :
- - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
- - alpha is the marginal relaxation hyperparameter
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of
+ matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+ - reg_mis the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
Parameters
----------
- A : np.ndarray (d,n)
- n training distributions a_i of size d
- M : np.ndarray (d,d)
- loss matrix for OT
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m : float
Marginal relaxation term > 0
- weights : np.ndarray (n,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ tau : float
+ Stabilization threshold for log domain absorption.
+ weights : np.ndarray (n_hists,) optional
+ Weight of each distribution (barycentric coodinates)
+ If None, uniform weights are used.
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshol on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -441,7 +702,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
Returns
-------
- a : (d,) ndarray
+ a : (dim,) ndarray
Unbalanced Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -450,12 +711,165 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré,
+ G. (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
"""
- p, n_hists = A.shape
+ dim, n_hists = A.shape
+ if weights is None:
+ weights = np.ones(n_hists) / n_hists
+ else:
+ assert(len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ fi = reg_m / (reg_m + reg)
+
+ u = np.ones((dim, n_hists)) / dim
+ v = np.ones((dim, n_hists)) / dim
+
+ # print(reg)
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ fi = reg_m / (reg_m + reg)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim)
+ beta = np.zeros(dim)
+ q = np.ones(dim) / dim
+ while (err > stopThr and cpt < numItermax):
+ qprev = q
+ Kv = K.dot(v)
+ f_alpha = np.exp(- alpha / (reg + reg_m))
+ f_beta = np.exp(- beta / (reg + reg_m))
+ f_alpha = f_alpha[:, None]
+ f_beta = f_beta[:, None]
+ u = ((A / (Kv + 1e-16)) ** fi) * f_alpha
+ Ktu = K.T.dot(u)
+ q = (Ktu ** (1 - fi)) * f_beta
+ q = q.dot(weights) ** (1 / (1 - fi))
+ Q = q[:, None]
+ v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta
+ absorbing = False
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ q = qprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = abs(q - qprev).max() / max(abs(q).max(),
+ abs(qprev).max(), 1.)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 50 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+ if err > stopThr:
+ warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
+ "Try a larger entropy `reg` or a lower mass `reg_m`." +
+ "Or a larger absorption threshold `tau`.")
+ if log:
+ log['niter'] = cpt
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
+ return q, log
+ else:
+ return q
+
+
+def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A.
+
+ The function solves the following optimization problem with a
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+ - reg_mis the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ Parameters
+ ----------
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
+ reg : float
+ Entropy regularization term > 0
+ reg_m: float
+ Marginal relaxation term > 0
+ weights : np.ndarray (n_hists,) optional
+ Weight of each distribution (barycentric coodinates)
+ If None, uniform weights are used.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Unbalanced Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+ (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprin
+ arXiv:1607.05816.
+
+
+ """
+ dim, n_hists = A.shape
if weights is None:
weights = np.ones(n_hists) / n_hists
else:
@@ -466,10 +880,10 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
K = np.exp(- M / reg)
- fi = alpha / (alpha + reg)
+ fi = reg_m / (reg_m + reg)
- v = np.ones((p, n_hists)) / p
- u = np.ones((p, 1)) / p
+ v = np.ones((dim, n_hists)) / dim
+ u = np.ones((dim, 1)) / dim
cpt = 0
err = 1.
@@ -498,8 +912,11 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \
- np.sum((v - vprev) ** 2) / np.sum((v) ** 2)
+ err_u = abs(u - uprev).max()
+ err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max()
+ err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
if verbose:
@@ -511,8 +928,95 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
cpt += 1
if log:
log['niter'] = cpt
- log['u'] = u
- log['v'] = v
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
return q, log
else:
return q
+
+
+def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False, **kwargs):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A.
+
+ The function solves the following optimization problem with a
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+ - reg_mis the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ Parameters
+ ----------
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
+ reg : float
+ Entropy regularization term > 0
+ reg_m: float
+ Marginal relaxation term > 0
+ weights : np.ndarray (n_hists,) optional
+ Weight of each distribution (barycentric coodinates)
+ If None, uniform weights are used.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Unbalanced Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+ (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprin
+ arXiv:1607.05816.
+
+ """
+
+ if method.lower() == 'sinkhorn':
+ return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return barycenter_unbalanced_stabilized(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
+ return barycenter_unbalanced(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)