summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py272
1 files changed, 249 insertions, 23 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index b59ee1b..2aa76ff 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -64,7 +64,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
solutions. Note that the greedy version of the sinkhorn
:py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aim at providing a
- fast approximation of the Sinkhorn problem.
+ fast approximation of the Sinkhorn problem. For use of GPU and gradient
+ computation with small number of iterations we strongly recommend the
+ :any:`ot.bregman.sinkhorn_log` solver that will not need to check for
+ numerical problems.
Parameters
@@ -79,8 +82,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
reg : float
Regularization term >0
method : str
- method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ method used for the solver either 'sinkhorn', 'sinkhorn_log',
+ 'greenkhorn', 'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling', see
+ those functions for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
@@ -118,6 +122,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
See Also
@@ -134,6 +139,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
**kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
elif method.lower() == 'greenkhorn':
return greenkhorn(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log)
@@ -182,7 +191,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
By default and when using a regularization parameter that is not too small
the default sinkhorn solver should be enough. If you need to use a small
regularization to get sharper OT matrices, you should use the
- :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
+ :any:`ot.bregman.sinkhorn_log` solver that will avoid numerical
errors. This last solver can be very slow in practice and might not even
converge to a reasonable OT matrix in a finite time. This is why
:any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
@@ -190,7 +199,10 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
solutions. Note that the greedy version of the sinkhorn
:any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
version of the sinkhorn :any:`ot.bregman.screenkhorn` aim at providing a
- fast approximation of the Sinkhorn problem.
+ fast approximation of the Sinkhorn problem. For use of GPU and gradient
+ computation with small number of iterations we strongly recommend the
+ :any:`ot.bregman.sinkhorn_log` solver that will not need to check for
+ numerical problems.
Parameters
----------
@@ -204,7 +216,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
reg : float
Regularization term >0
method : str
- method used for the solver either 'sinkhorn', 'sinkhorn_stabilized', see those function for specific parameters
+ method used for the solver either 'sinkhorn', 'sinkhorn_log',
+ 'sinkhorn_stabilized', see those functions for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
@@ -230,7 +243,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
>>> ot.sinkhorn2(a, b, M, 1)
- array([0.26894142])
+ 0.26894142136999516
.. _references-sinkhorn2:
@@ -243,7 +256,11 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
- .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+ .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation
+ algorithms for optimal transport via Sinkhorn iteration, Advances in Neural
+ Information Processing Systems (NIPS) 31, 2017
+
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
@@ -257,20 +274,45 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
"""
- b = list_to_array(b)
+ M, a, b = list_to_array(M, a, b)
+ nx = get_backend(M, a, b)
+
if len(b.shape) < 2:
- b = b[:, None]
+ if method.lower() == 'sinkhorn':
+ res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ res = sinkhorn_log(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+ if log:
+ return nx.sum(M * res[0]), res[1]
+ else:
+ return nx.sum(M * res)
- if method.lower() == 'sinkhorn':
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- elif method.lower() == 'sinkhorn_stabilized':
- return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
else:
- raise ValueError("Unknown method '%s'." % method)
+
+ if method.lower() == 'sinkhorn':
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
@@ -361,7 +403,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
# init data
dim_a = len(a)
- dim_b = len(b)
+ dim_b = b.shape[0]
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -438,6 +480,191 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
def sinkhorn_log(a, b, M, reg, numItermax=1000,
                 stopThr=1e-9, verbose=False, log=False, **kwargs):
    r"""
    Solve the entropic regularization optimal transport problem in log space
    and return the OT matrix

    The function solves the following optimization problem:

    .. math::
        \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

        s.t. \gamma 1 = a

             \gamma^T 1= b

             \gamma\geq 0
    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm :ref:`[2] <references-sinkhorn-log>` with the
    implementation from :ref:`[34] <references-sinkhorn-log>`


    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain
    b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
        samples in the target domain, compute sinkhorn with multiple targets
        and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log)
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters

    Examples
    --------

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.sinkhorn(a, b, M, 1, method='sinkhorn_log')
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])


    .. _references-sinkhorn-log:
    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
        Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013

    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.


    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """

    a, b, M = list_to_array(a, b, M)

    nx = get_backend(M, a, b)

    # default to uniform marginals when empty weights are given
    if len(a) == 0:
        a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
    if len(b) == 0:
        b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)

    # init data
    dim_a = len(a)
    dim_b = b.shape[0]

    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    if n_hists:  # we do not want to use tensors so we do a loop

        # solve one 1-D problem per target histogram and collect the losses
        # (and, when requested, the log-domain dual potentials)
        lst_loss = []
        lst_u = []
        lst_v = []

        for k in range(n_hists):
            res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax,
                               stopThr=stopThr, verbose=verbose, log=log, **kwargs)

            if log:
                # res = (gamma, log_dict): loss is <gamma, M>_F
                lst_loss.append(nx.sum(M * res[0]))
                lst_u.append(res[1]['log_u'])
                lst_v.append(res[1]['log_v'])
            else:
                lst_loss.append(nx.sum(M * res))
        res = nx.stack(lst_loss)
        if log:
            # stack per-histogram potentials column-wise, shape (dim, n_hists)
            log = {'log_u': nx.stack(lst_u, 1),
                   'log_v': nx.stack(lst_v, 1), }
            log['u'] = nx.exp(log['log_u'])
            log['v'] = nx.exp(log['log_v'])
            return res, log
        else:
            return res

    else:

        if log:
            log = {'err': []}

        # cost matrix scaled into the log-kernel: K = exp(Mr)
        Mr = M / (-reg)

        # we assume that no distances are null except those of the diagonal of
        # distances

        # log-domain scaling vectors (u = log of the usual Sinkhorn scalings)
        u = nx.zeros(dim_a, type_as=M)
        v = nx.zeros(dim_b, type_as=M)

        def get_logT(u, v):
            # log of the transport plan for the current potentials
            if n_hists:
                return Mr[:, :, None] + u + v
            else:
                return Mr + u[:, None] + v[None, :]

        loga = nx.log(a)
        logb = nx.log(b)

        cpt = 0
        err = 1
        while (err > stopThr and cpt < numItermax):

            # alternate log-domain Sinkhorn updates (stable softmin form)
            v = logb - nx.logsumexp(Mr + u[:, None], 0)
            u = loga - nx.logsumexp(Mr + v[None, :], 1)

            if cpt % 10 == 0:
                # we can speed up the process by checking for the error only all
                # the 10th iterations

                # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
                tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0)
                err = nx.norm(tmp2 - b)  # violation of marginal
                if log:
                    log['err'].append(err)

                if verbose:
                    if cpt % 200 == 0:
                        print(
                            '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                    print('{:5d}|{:8e}|'.format(cpt, err))
            cpt = cpt + 1

        if log:
            log['log_u'] = u
            log['log_v'] = v
            log['u'] = nx.exp(u)
            log['v'] = nx.exp(v)

            return nx.exp(get_logT(u, v)), log

        else:
            return nx.exp(get_logT(u, v))
+
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
log=False):
r"""
@@ -1881,8 +2108,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
return (f, g)
else:
- M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric)
- M = nx.from_numpy(M, type_as=a)
+ M = dist(X_s, X_t, metric=metric)
if log:
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
return pi, log
@@ -2102,7 +2328,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
>>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
>>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
>>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS
- array([1.499...])
+ 1.499887176049052
References