Diffstat (limited to 'ot/bregman.py')
 ot/bregman.py | 2811 ++++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 2042 insertions(+), 769 deletions(-)
diff --git a/ot/bregman.py b/ot/bregman.py
index f1f8437..cce52e2 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -7,70 +7,104 @@ Bregman projections solvers for entropic regularized OT
# Nicolas Courty <ncourty@irisa.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
-# Hicham Janati <hicham.janati@inria.fr>
+# Hicham Janati <hicham.janati100@gmail.com>
# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
# Alexander Tong <alexander.tong@yale.edu>
# Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License
-import numpy as np
import warnings
-from .utils import unif, dist
+
+import numpy as np
from scipy.optimize import fmin_l_bfgs_b
+from ot.utils import unif, dist, list_to_array
+from .backend import get_backend
+
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+ stopThr=1e-9, verbose=False, log=False, warn=True,
+ **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- s.t. \gamma 1 = a
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma^T 1= b
+ \gamma &\geq 0
- \gamma\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn>`
+
+ **Choosing a Sinkhorn solver**
+
+ By default and when using a regularization parameter that is not too small
+ the default sinkhorn solver should be enough. If you need to use a small
+ regularization to get sharper OT matrices, you should use the
+ :py:func:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
+ errors. This last solver can be very slow in practice and might not even
+ converge to a reasonable OT matrix in a finite time. This is why
+ :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
+ of the regularization (and using warm start) sometimes leads to better
+ solutions. Note that the greedy version of the sinkhorn
+ :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
+ version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aims at providing a
+ fast approximation of the Sinkhorn problem. For GPU use and gradient
+ computation with a small number of iterations we strongly recommend the
+ :py:func:`ot.bregman.sinkhorn_log` solver, which does not need to check for
+ numerical problems.
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
method : str
- method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ method used for the solver either 'sinkhorn', 'sinkhorn_log',
+ 'greenkhorn', 'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling', see
+ those functions for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
-
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -86,102 +120,152 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
-
+ .. _references-sinkhorn:
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information Processing
+ Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
+ for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé,
+ A., & Peyré, G. (2019, April). Interpolating between optimal transport
+ and MMD using Sinkhorn divergences. In The 22nd International Conference
+ on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
- ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
- ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
- ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
+ ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn>`
+ ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn
+ :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
+ ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling
+ :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
"""
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn,
**kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn,
+ **kwargs)
elif method.lower() == 'greenkhorn':
return greenkhorn(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log)
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn)
elif method.lower() == 'sinkhorn_stabilized':
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ log=log, warn=warn,
+ **kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
return sinkhorn_epsilon_scaling(a, b, M, reg,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ log=log, warn=warn,
+ **kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
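Aside: a minimal usage sketch of the dispatcher above, not part of the patch (assumes a standard POT install; the regularization values are purely illustrative):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs, xt = rng.randn(50, 2), rng.randn(60, 2)
    a, b = ot.unif(50), ot.unif(60)          # uniform source / target weights
    M = ot.dist(xs, xt)                      # squared Euclidean cost

    G = ot.sinkhorn(a, b, M, reg=1e-1)                                 # default solver
    G_log = ot.sinkhorn(a, b, M, reg=1e-3, method='sinkhorn_log')      # small reg, log domain
    G_stab = ot.sinkhorn(a, b, M, reg=1e-3, method='sinkhorn_stabilized')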
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+ stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the loss
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- s.t. \gamma 1 = a
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma^T 1= b
+ \gamma &\geq 0
- \gamma\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn2>`
+
+
+ **Choosing a Sinkhorn solver**
+
+ By default and when using a regularization parameter that is not too small
+ the default sinkhorn solver should be enough. If you need to use a small
+ regularization to get sharper OT matrices, you should use the
+ :py:func:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
+ errors. This last solver can be very slow in practice and might not even
+ converge to a reasonable OT matrix in a finite time. This is why
+ :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
+ of the regularization (and using warm start) sometimes leads to better
+ solutions. Note that the greedy version of the sinkhorn
+ :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
+ version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aims at providing a
+ fast approximation of the Sinkhorn problem. For GPU use and gradient
+ computation with a small number of iterations we strongly recommend the
+ :py:func:`ot.bregman.sinkhorn_log` solver, which does not need to check for
+ numerical problems.
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
method : str
- method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ method used for the solver either 'sinkhorn', 'sinkhorn_log' or
+ 'sinkhorn_stabilized', see those functions for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- W : (n_hists) ndarray or float
+ W : float or array-like, shape (n_hists,)
Optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
+
Examples
--------
@@ -190,99 +274,142 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
>>> ot.sinkhorn2(a, b, M, 1)
- array([0.26894142])
-
+ 0.26894142136999516
+ .. _references-sinkhorn2:
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
+ Optimal Transport, Advances in Neural Information
+ Processing Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
+ for Entropy Regularized Transport Problems.
+ arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
- [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+ .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation
+ algorithms for optimal transport via Sinkhorn iteration,
+ Advances in Neural Information Processing Systems (NIPS) 31, 2017
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I.,
+ Trouvé, A., & Peyré, G. (2019, April).
+ Interpolating between optimal transport and MMD using Sinkhorn
+ divergences. In The 22nd International Conference on Artificial
+ Intelligence and Statistics (pp. 2681-2690). PMLR.
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
- ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
- ot.bregman.greenkhorn : Greenkhorn [21]
- ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
- ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
-
+ ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn2>`
+ ot.bregman.greenkhorn : Greenkhorn :ref:`[21] <references-sinkhorn2>`
+ ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn
+ :ref:`[9] <references-sinkhorn2>` :ref:`[10] <references-sinkhorn2>`
"""
- b = np.asarray(b, dtype=np.float64)
+
+ M, a, b = list_to_array(M, a, b)
+ nx = get_backend(M, a, b)
+
if len(b.shape) < 2:
- b = b[:, None]
- if method.lower() == 'sinkhorn':
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- elif method.lower() == 'sinkhorn_stabilized':
- return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- elif method.lower() == 'sinkhorn_epsilon_scaling':
- return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ if method.lower() == 'sinkhorn':
+ res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ res = sinkhorn_log(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+ if log:
+ return nx.sum(M * res[0]), res[1]
+ else:
+ return nx.sum(M * res)
+
else:
- raise ValueError("Unknown method '%s'." % method)
+
+ if method.lower() == 'sinkhorn':
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
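Aside: a small sketch of the two branches above, not part of the patch. A 1-D b takes the first path and returns a scalar loss computed as nx.sum(M * gamma); a (dim_b, n_hists) matrix returns one loss per column:

    import numpy as np
    import ot

    a = np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])

    loss = ot.sinkhorn2(a, np.array([.5, .5]), M, 1.)   # scalar loss
    B = np.array([[.5, .3], [.5, .7]])                  # two target histograms as columns
    losses = ot.sinkhorn2(a, B, M, 1.)                  # one loss per histogram, shape (2,)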
-def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
+ verbose=False, log=False, warn=True,
+ **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp
+ matrix scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-knopp>`
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -299,10 +426,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
[0.13447071, 0.36552929]])
+ .. _references-sinkhorn-knopp:
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information
+ Processing Systems (NIPS) 26, 2013
See Also
@@ -312,18 +442,18 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)
# init data
dim_a = len(a)
- dim_b = len(b)
+ dim_b = b.shape[0]
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -336,66 +466,64 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((dim_a, n_hists)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
- u = np.ones(dim_a) / dim_a
- v = np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
- # print(reg)
-
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
-
- # print(np.min(K))
- tmp2 = np.empty(b.shape, dtype=M.dtype)
+ K = nx.exp(M / (-reg))
Kp = (1 / a).reshape(-1, 1) * K
- cpt = 0
+
err = 1
- while (err > stopThr and cpt < numItermax):
+ for ii in range(numItermax):
uprev = u
vprev = v
+ KtransposeU = nx.dot(K.T, u)
+ v = b / KtransposeU
+ u = 1. / nx.dot(Kp, v)
- KtransposeU = np.dot(K.T, u)
- v = np.divide(b, KtransposeU)
- u = 1. / np.dot(Kp, v)
-
- if (np.any(KtransposeU == 0)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (nx.any(KtransposeU == 0)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- print('Warning: numerical errors at iteration', cpt)
+ warnings.warn('Warning: numerical errors at iteration %d' % ii)
u = uprev
v = vprev
break
- if cpt % 10 == 0:
+ if ii % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
if n_hists:
- np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2)
+ tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v)
else:
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
- np.einsum('i,ij,j->j', u, K, v, out=tmp2)
- err = np.linalg.norm(tmp2 - b) # violation of marginal
+ tmp2 = nx.einsum('i,ij,j->j', u, K, v)
+ err = nx.norm(tmp2 - b) # violation of marginal
if log:
log['err'].append(err)
+ if err < stopThr:
+ break
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
- cpt = cpt + 1
+ print('{:5d}|{:8e}|'.format(ii, err))
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
+ log['niter'] = ii
log['u'] = u
log['v'] = v
if n_hists: # return only loss
- res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
+ res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
else:
@@ -409,58 +537,259 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
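Aside: stripped of the backend layer, the error checks and the log, the loop above reduces to the classic two-line scaling update. A minimal NumPy sketch (hypothetical helper, not part of the module):

    import numpy as np

    def sinkhorn_plain(a, b, M, reg, n_iter=1000):
        # K = exp(-M/reg); alternately enforce the two marginals
        K = np.exp(-M / reg)
        u = np.ones_like(a) / len(a)
        for _ in range(n_iter):
            v = b / (K.T @ u)   # match column marginal b
            u = a / (K @ v)     # match row marginal a
        return u[:, None] * K * v[None, :]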
-def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
- log=False):
+def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
+ log=False, warn=True, **kwargs):
r"""
- Solve the entropic regularization optimal transport problem and return the OT matrix
+ Solve the entropic regularization optimal transport problem in log space
+ and return the OT matrix
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
+
+ \gamma^T \mathbf{1} &= \mathbf{b}
+
+ \gamma &\geq 0
+ where :
+
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm :ref:`[2] <references-sinkhorn-log>` with the
+ implementation from :ref:`[34] <references-sinkhorn-log>`
+
+
+ Parameters
+ ----------
+ a : array-like, shape (dim_a,)
+ samples weights in the source domain
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
+ samples in the target domain, compute sinkhorn with multiple targets
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Regularization term >0
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
+
+ Returns
+ -------
+ gamma : array-like, shape (dim_a, dim_b)
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.], [1., 0.]]
+ >>> ot.bregman.sinkhorn_log(a, b, M, 1)
+ array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]])
+
+
+ .. _references-sinkhorn-log:
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
+ Optimal Transport, Advances in Neural Information Processing
+ Systems (NIPS) 26, 2013
+
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I.,
+ Trouvé, A., & Peyré, G. (2019, April). Interpolating between
+ optimal transport and MMD using Sinkhorn divergences. In The
+ 22nd International Conference on Artificial Intelligence and
+ Statistics (pp. 2681-2690). PMLR.
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
+
+ if len(a) == 0:
+ a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
+ if len(b) == 0:
+ b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)
+
+ # init data
+ dim_a = len(a)
+ dim_b = b.shape[0]
+
+ if len(b.shape) > 1:
+ n_hists = b.shape[1]
+ else:
+ n_hists = 0
- The algorithm used is based on the paper
+ if n_hists: # we do not want to use tensors, so we do a loop
- Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration
- by Jason Altschuler, Jonathan Weed, Philippe Rigollet
- appeared at NIPS 2017
+ lst_loss = []
+ lst_u = []
+ lst_v = []
- which is a stochastic version of the Sinkhorn-Knopp algorithm [2].
+ for k in range(n_hists):
+ res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+
+ if log:
+ lst_loss.append(nx.sum(M * res[0]))
+ lst_u.append(res[1]['log_u'])
+ lst_v.append(res[1]['log_v'])
+ else:
+ lst_loss.append(nx.sum(M * res))
+ res = nx.stack(lst_loss)
+ if log:
+ log = {'log_u': nx.stack(lst_u, 1),
+ 'log_v': nx.stack(lst_v, 1), }
+ log['u'] = nx.exp(log['log_u'])
+ log['v'] = nx.exp(log['log_v'])
+ return res, log
+ else:
+ return res
+
+ else:
+
+ if log:
+ log = {'err': []}
+
+ Mr = - M / reg
+
+ # we assume that no distances are null except those of the diagonal of
+ # distances
+
+ u = nx.zeros(dim_a, type_as=M)
+ v = nx.zeros(dim_b, type_as=M)
+
+ def get_logT(u, v):
+ if n_hists:
+ return Mr[:, :, None] + u + v
+ else:
+ return Mr + u[:, None] + v[None, :]
+
+ loga = nx.log(a)
+ logb = nx.log(b)
+
+ err = 1
+ for ii in range(numItermax):
+
+ v = logb - nx.logsumexp(Mr + u[:, None], 0)
+ u = loga - nx.logsumexp(Mr + v[None, :], 1)
+
+ if ii % 10 == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+
+ # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
+ tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0)
+ err = nx.norm(tmp2 - b) # violation of marginal
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if ii % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr:
+ break
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+
+ if log:
+ log['niter'] = ii
+ log['log_u'] = u
+ log['log_v'] = v
+ log['u'] = nx.exp(u)
+ log['v'] = nx.exp(v)
+
+ return nx.exp(get_logT(u, v)), log
+
+ else:
+ return nx.exp(get_logT(u, v))
+
+
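Aside: the same recursion in plain NumPy, with scipy.special.logsumexp standing in for nx.logsumexp; a sketch of the log-domain updates above (hypothetical helper, not part of the module):

    import numpy as np
    from scipy.special import logsumexp

    def sinkhorn_log_plain(a, b, M, reg, n_iter=1000):
        Mr = -M / reg
        u, v = np.zeros(len(a)), np.zeros(len(b))
        for _ in range(n_iter):
            # dual updates in log space; nothing is exponentiated until the end
            v = np.log(b) - logsumexp(Mr + u[:, None], axis=0)
            u = np.log(a) - logsumexp(Mr + v[None, :], axis=1)
        return np.exp(Mr + u[:, None] + v[None, :])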
+def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
+ log=False, warn=True):
+ r"""
+ Solve the entropic regularization optimal transport problem and return the OT matrix
+
+ The algorithm used is based on the paper :ref:`[22] <references-greenkhorn>`
+ which is a stochastic version of the Sinkhorn-Knopp
+ algorithm :ref:`[2] <references-greenkhorn>`
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
-
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -477,11 +806,18 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
[0.13447071, 0.36552929]])
+ .. _references-greenkhorn:
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
- [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information
+ Processing Systems (NIPS) 26, 2013
+
+ .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time
+ approximation algorithms for optimal transport via Sinkhorn
+ iteration, Advances in Neural Information Processing
+ Systems (NIPS) 31, 2017
See Also
@@ -491,68 +827,70 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
+ if nx.__name__ == "jax":
+ raise TypeError("JAX arrays have been received. Greenkhorn is not "
+ "compatible with JAX")
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
dim_a = a.shape[0]
dim_b = b.shape[0]
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty_like(M)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
- u = np.full(dim_a, 1. / dim_a)
- v = np.full(dim_b, 1. / dim_b)
- G = u[:, np.newaxis] * K * v[np.newaxis, :]
+ u = nx.full((dim_a,), 1. / dim_a, type_as=K)
+ v = nx.full((dim_b,), 1. / dim_b, type_as=K)
+ G = u[:, None] * K * v[None, :]
- viol = G.sum(1) - a
- viol_2 = G.sum(0) - b
+ viol = nx.sum(G, axis=1) - a
+ viol_2 = nx.sum(G, axis=0) - b
stopThr_val = 1
-
if log:
log = dict()
log['u'] = u
log['v'] = v
- for i in range(numItermax):
- i_1 = np.argmax(np.abs(viol))
- i_2 = np.argmax(np.abs(viol_2))
- m_viol_1 = np.abs(viol[i_1])
- m_viol_2 = np.abs(viol_2[i_2])
- stopThr_val = np.maximum(m_viol_1, m_viol_2)
+ for ii in range(numItermax):
+ i_1 = nx.argmax(nx.abs(viol))
+ i_2 = nx.argmax(nx.abs(viol_2))
+ m_viol_1 = nx.abs(viol[i_1])
+ m_viol_2 = nx.abs(viol_2[i_2])
+ stopThr_val = nx.maximum(m_viol_1, m_viol_2)
if m_viol_1 > m_viol_2:
old_u = u[i_1]
- u[i_1] = a[i_1] / (K[i_1, :].dot(v))
- G[i_1, :] = u[i_1] * K[i_1, :] * v
-
- viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1]
- viol_2 += (K[i_1, :].T * (u[i_1] - old_u) * v)
+ new_u = a[i_1] / (K[i_1, :].dot(v))
+ G[i_1, :] = new_u * K[i_1, :] * v
+ viol[i_1] = new_u * K[i_1, :].dot(v) - a[i_1]
+ viol_2 += (K[i_1, :].T * (new_u - old_u) * v)
+ u[i_1] = new_u
else:
old_v = v[i_2]
- v[i_2] = b[i_2] / (K[:, i_2].T.dot(u))
- G[:, i_2] = u * K[:, i_2] * v[i_2]
+ new_v = b[i_2] / (K[:, i_2].T.dot(u))
+ G[:, i_2] = u * K[:, i_2] * new_v
# aviol = (G@one_m - a)
# aviol_2 = (G.T@one_n - b)
- viol += (-old_v + v[i_2]) * K[:, i_2] * u
- viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]
-
- # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
+ viol += (-old_v + new_v) * K[:, i_2] * u
+ viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2]
+ v[i_2] = new_v
if stopThr_val <= stopThr:
break
else:
- print('Warning: Algorithm did not converge')
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
+ log["n_iter"] = ii
log['u'] = u
log['v'] = v
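Aside: a condensed NumPy sketch of one greedy step of the loop above (hypothetical helper, not part of the module): rescale only the row or column whose marginal violation is currently largest, then propagate the change to the other set of violations; all updates are in place, as in the code.

    import numpy as np

    def greenkhorn_step(a, b, K, u, v, viol, viol_2):
        i_1, i_2 = np.argmax(np.abs(viol)), np.argmax(np.abs(viol_2))
        if np.abs(viol[i_1]) > np.abs(viol_2[i_2]):
            old_u = u[i_1]
            u[i_1] = a[i_1] / (K[i_1, :] @ v)            # row i_1 now matches a[i_1] exactly
            viol[i_1] = 0.
            viol_2 += K[i_1, :] * (u[i_1] - old_u) * v   # update column violations
        else:
            old_v = v[i_2]
            v[i_2] = b[i_2] / (K[:, i_2] @ u)            # column i_2 now matches b[i_2] exactly
            viol_2[i_2] = 0.
            viol += K[:, i_2] * (v[i_2] - old_v) * u     # update row violations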
@@ -564,58 +902,66 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
warmstart=None, verbose=False, print_period=20,
- log=False, **kwargs):
+ log=False, warn=True, **kwargs):
r"""
Solve the entropic regularization OT problem with log stabilization
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
- scaling algorithm as proposed in [2]_ but with the log stabilization
- proposed in [10]_ an defined in [9]_ (Algo 3.1) .
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-stabilized>`
+ but with the log stabilization
+ proposed in :ref:`[10] <references-sinkhorn-stabilized>` and defined in
+ :ref:`[9] <references-sinkhorn-stabilized>` (Algo 3.1).
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,)
+ b : array-like, shape (dim_b,)
samples in the target domain
- M : ndarray, shape (dim_a, dim_b)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
tau : float
- thershold for max value in u or v for log scaling
- warmstart : tible of vectors
- if given then sarting values for alpha an beta log scalings
+ threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
+ for log scaling
+ warmstart : tuple of vectors
+ if given then starting values for alpha and beta log scalings
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -632,14 +978,21 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
[0.13447071, 0.36552929]])
+ .. _references-sinkhorn-stabilized:
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
+ Optimal Transport, Advances in Neural Information Processing
+ Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
+ for Entropy Regularized Transport Problems.
+ arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
See Also
@@ -649,19 +1002,19 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
# test if multiple target
if len(b.shape) > 1:
n_hists = b.shape[1]
- a = a[:, np.newaxis]
+ a = a[:, None]
else:
n_hists = 0
@@ -669,123 +1022,123 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
dim_a = len(a)
dim_b = len(b)
- cpt = 0
if log:
log = {'err': []}
# we assume that no distances are null except those of the diagonal of
# distances
if warmstart is None:
- alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
+ alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)
else:
alpha, beta = warmstart
if n_hists:
- u = np.ones((dim_a, n_hists)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
- u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b
+ u, v = nx.ones(dim_a, type_as=M), nx.ones(dim_b, type_as=M)
+ u /= dim_a
+ v /= dim_b
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((dim_a, 1))
+ return nx.exp(-(M - alpha.reshape((dim_a, 1))
- beta.reshape((1, dim_b))) / reg)
def get_Gamma(alpha, beta, u, v):
"""log space gamma computation"""
- return np.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b)))
- / reg + np.log(u.reshape((dim_a, 1))) + np.log(v.reshape((1, dim_b))))
-
- # print(np.min(K))
+ return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b)))
+ / reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b))))
K = get_K(alpha, beta)
transp = K
- loop = 1
- cpt = 0
err = 1
- while loop:
+ for ii in range(numItermax):
uprev = u
vprev = v
# sinkhorn update
- v = b / (np.dot(K.T, u) + 1e-16)
- u = a / (np.dot(K, v) + 1e-16)
+ v = b / (nx.dot(K.T, u))
+ u = a / (nx.dot(K, v))
# remove numerical problems and store them in K
- if np.abs(u).max() > tau or np.abs(v).max() > tau:
+ if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau:
if n_hists:
- alpha, beta = alpha + reg * \
- np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
+ alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v))
else:
- alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
+ alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v)
if n_hists:
- u, v = np.ones((dim_a, n_hists)) / dim_a, np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
- u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
K = get_K(alpha, beta)
- if cpt % print_period == 0:
+ if ii % print_period == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
if n_hists:
- err_u = abs(u - uprev).max()
- err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
- err_v = abs(v - vprev).max()
- err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+ err_u = nx.max(nx.abs(u - uprev))
+ err_u /= max(nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0)
+ err_v = nx.max(nx.abs(v - vprev))
+ err_v /= max(nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0)
err = 0.5 * (err_u + err_v)
else:
transp = get_Gamma(alpha, beta, u, v)
- err = np.linalg.norm((np.sum(transp, axis=0) - b))
+ err = nx.norm(nx.sum(transp, axis=0) - b)
if log:
log['err'].append(err)
if verbose:
- if cpt % (print_period * 20) == 0:
+ if ii % (print_period * 20) == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
+ print('{:5d}|{:8e}|'.format(ii, err))
if err <= stopThr:
- loop = False
-
- if cpt >= numItermax:
- loop = False
+ break
- if np.any(np.isnan(u)) or np.any(np.isnan(v)):
+ if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)):
# we have reached the machine precision
# come back to previous solution and quit loop
- print('Warning: numerical errors at iteration', cpt)
+ warnings.warn('Numerical errors at iteration %d' % ii)
u = uprev
v = vprev
break
-
- cpt = cpt + 1
-
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
if n_hists:
alpha = alpha[:, None]
beta = beta[:, None]
- logu = alpha / reg + np.log(u)
- logv = beta / reg + np.log(v)
+ logu = alpha / reg + nx.log(u)
+ logv = beta / reg + nx.log(v)
+ log["n_iter"] = ii
log['logu'] = logu
log['logv'] = logv
- log['alpha'] = alpha + reg * np.log(u)
- log['beta'] = beta + reg * np.log(v)
+ log['alpha'] = alpha + reg * nx.log(u)
+ log['beta'] = beta + reg * nx.log(v)
log['warmstart'] = (log['alpha'], log['beta'])
if n_hists:
- res = np.zeros((n_hists))
- for i in range(n_hists):
- res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ res = nx.stack([
+ nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ for i in range(n_hists)
+ ])
return res, log
else:
return get_Gamma(alpha, beta, u, v), log
else:
if n_hists:
- res = np.zeros((n_hists))
- for i in range(n_hists):
- res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ res = nx.stack([
+ nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ for i in range(n_hists)
+ ])
return res
else:
return get_Gamma(alpha, beta, u, v)
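Aside: the robustness of this solver comes from the absorption step above: whenever u or v exceeds tau, their logs are folded into the dual potentials alpha, beta and the scalings restart near one. A condensed NumPy sketch (hypothetical helper, not part of the module):

    import numpy as np

    def absorb(alpha, beta, u, v, reg, tau=1e3):
        # fold large scalings into the duals so K = exp(-(M - alpha - beta)/reg)
        # stays in a numerically safe range
        if np.abs(u).max() > tau or np.abs(v).max() > tau:
            alpha = alpha + reg * np.log(u)
            beta = beta + reg * np.log(v)
            u = np.ones_like(u) / len(u)
            v = np.ones_like(v) / len(v)
        return alpha, beta, u, v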
@@ -794,70 +1147,73 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
numInnerItermax=100, tau=1e3, stopThr=1e-9,
warmstart=None, verbose=False, print_period=10,
- log=False, **kwargs):
+ log=False, warn=True, **kwargs):
r"""
Solve the entropic regularization optimal transport problem with log
stabilization and epsilon scaling.
-
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
- where :
+ \gamma &\geq 0
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ where :
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
- scaling algorithm as proposed in [2]_ but with the log stabilization
- proposed in [10]_ and the log scaling proposed in [9]_ algorithm 3.2
-
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-epsilon-scaling>`
+ but with the log stabilization
+ proposed in :ref:`[10] <references-sinkhorn-epsilon-scaling>` and the log scaling
+ proposed in :ref:`[9] <references-sinkhorn-epsilon-scaling>` algorithm 3.2
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,)
+ b : array-like, shape (dim_b,)
samples in the target domain
- M : ndarray, shape (dim_a, dim_b)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
tau : float
- thershold for max value in u or v for log scaling
+ threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
+ for log scaling
warmstart : tuple of vectors
- if given then sarting values for alpha an beta log scalings
+ if given then starting values for alpha and beta log scalings
numItermax : int, optional
Max number of iterations
numInnerItermax : int, optional
- Max number of iterationsin the inner slog stabilized sinkhorn
+ Max number of iterations in the inner log stabilized sinkhorn
epsilon0 : int, optional
first epsilon regularization value (then exponential decrease to reg)
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
-
Examples
--------
-
>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
@@ -866,29 +1222,32 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
-
+ .. _references-sinkhorn-epsilon-scaling:
References
----------
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
-
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
# init data
dim_a = len(a)
@@ -898,14 +1257,14 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
numItermin = 35
numItermax = max(numItermin, numItermax) # ensure that last value is exact
- cpt = 0
+ ii = 0
if log:
log = {'err': []}
# we assume that no distances are null except those of the diagonal of
# distances
if warmstart is None:
- alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
+ alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)
else:
alpha, beta = warmstart
@@ -913,12 +1272,10 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
def get_reg(n): # exponential decreasing
return (epsilon0 - reg) * np.exp(-n) + reg
- loop = 1
- cpt = 0
err = 1
- while loop:
+ for ii in range(numItermax):
- regi = get_reg(cpt)
+ regi = get_reg(ii)
G, logi = sinkhorn_stabilized(a, b, M, regi,
numItermax=numInnerItermax, stopThr=1e-9,
@@ -928,33 +1285,31 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
alpha = logi['alpha']
beta = logi['beta']
- if cpt >= numItermax:
- loop = False
-
- if cpt % (print_period) == 0: # spsion nearly converged
+ if ii % (print_period) == 0: # epsilon nearly converged
# we can speed up the process by checking for the error only all
# the 10th iterations
transp = G
- err = np.linalg.norm(
- (np.sum(transp, axis=0) - b)) ** 2 + np.linalg.norm((np.sum(transp, axis=1) - a)) ** 2
+ err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + nx.norm(nx.sum(transp, axis=1) - a) ** 2
if log:
log['err'].append(err)
if verbose:
- if cpt % (print_period * 10) == 0:
- print(
- '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- if err <= stopThr and cpt > numItermin:
- loop = False
+ if ii % (print_period * 10) == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
- cpt = cpt + 1
- # print('err=',err,' cpt=',cpt)
+ if err <= stopThr and ii > numItermin:
+ break
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
log['alpha'] = alpha
log['beta'] = beta
log['warmstart'] = (log['alpha'], log['beta'])
+ log['niter'] = ii
return G, log
else:
return G
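Aside: the loop above is driven by get_reg, which decays the effective regularization exponentially from epsilon0 toward reg while each inner sinkhorn_stabilized call is warm-started from the previous duals. A sketch of the schedule alone (hypothetical helper, not part of the module):

    import numpy as np

    def reg_schedule(reg, epsilon0, n_outer):
        # same decay as get_reg above: (epsilon0 - reg) * exp(-k) + reg
        return [(epsilon0 - reg) * np.exp(-k) + reg for k in range(n_outer)]

    # reg_schedule(1e-2, 1e4, 6) -> roughly [1e4, 3.7e3, 1.4e3, 5.0e2, 1.8e2, 6.7e1]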
@@ -962,76 +1317,94 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
def geometricBar(weights, alldistribT):
"""return the weighted geometric mean of distributions"""
+ weights, alldistribT = list_to_array(weights, alldistribT)
+ nx = get_backend(weights, alldistribT)
assert (len(weights) == alldistribT.shape[1])
- return np.exp(np.dot(np.log(alldistribT), weights.T))
+ return nx.exp(nx.dot(nx.log(alldistribT), weights.T))
def geometricMean(alldistribT):
"""return the geometric mean of distributions"""
- return np.exp(np.mean(np.log(alldistribT), axis=1))
+ alldistribT = list_to_array(alldistribT)
+ nx = get_backend(alldistribT)
+ return nx.exp(nx.mean(nx.log(alldistribT), axis=1))
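Aside: a quick NumPy check, not part of the patch, of how the two helpers above relate: geometricBar with uniform weights reduces to geometricMean.

    import numpy as np

    T = np.array([[1., 4.], [2., 8.]])        # two distributions as columns
    w = np.array([.5, .5])
    gb = np.exp(np.log(T) @ w)                # weighted geometric mean (geometricBar)
    gm = np.exp(np.mean(np.log(T), axis=1))   # plain geometric mean (geometricMean)
    assert np.allclose(gb, gm)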
def projR(gamma, p):
"""return the KL projection on the row constrints """
- return np.multiply(gamma.T, p / np.maximum(np.sum(gamma, axis=1), 1e-10)).T
+ gamma, p = list_to_array(gamma, p)
+ nx = get_backend(gamma, p)
+ return (gamma.T * p / nx.maximum(nx.sum(gamma, axis=1), 1e-10)).T
def projC(gamma, q):
"""return the KL projection on the column constrints """
- return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10))
+ gamma, q = list_to_array(gamma, q)
+ nx = get_backend(gamma, q)
+ return gamma * q / nx.maximum(nx.sum(gamma, axis=0), 1e-10)
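Aside: a tiny NumPy check, not part of the patch, of the two KL projections above: after projR the row sums equal p, after projC the column sums equal q.

    import numpy as np

    gamma = np.ones((3, 4))
    p, q = np.full(3, 1. / 3), np.full(4, .25)
    rowp = (gamma.T * p / np.maximum(gamma.sum(axis=1), 1e-10)).T   # projR
    colp = gamma * q / np.maximum(gamma.sum(axis=0), 1e-10)         # projC
    assert np.allclose(rowp.sum(axis=1), p)
    assert np.allclose(colp.sum(axis=0), q)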
def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
- stopThr=1e-4, verbose=False, log=False, **kwargs):
- r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`
The function solves the following optimization problem:
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
+ distance (see :py:func:`ot.bregman.sinkhorn`)
+ if `method` is `sinkhorn` or `sinkhorn_stabilized` or `sinkhorn_log`.
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling
+ algorithm as proposed in :ref:`[3] <references-barycenter>`
Parameters
----------
- A : ndarray, shape (dim, n_hists)
- n_hists training distributions a_i of size dim
- M : ndarray, shape (dim, dim)
+ A : array-like, shape (dim, n_hists)
+ `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
+ M : array-like, shape (dim, dim)
loss matrix for OT
reg : float
Regularization term > 0
method : str (optional)
- method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized'
- weights : ndarray, shape (n_hists,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' or 'sinkhorn_log'
+ weights : array-like, shape (n_hists,)
+ Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
+ .. _references-barycenter:
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+ Iterative Bregman projections for regularized transportation problems.
+ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
"""
@@ -1039,232 +1412,327 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
return barycenter_sinkhorn(A, M, reg, weights=weights,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
return barycenter_stabilized(A, M, reg, weights=weights,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ log=log, warn=warn, **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return _barycenter_sinkhorn_log(A, M, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, **kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
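
A minimal usage sketch (not part of the patch), assuming the public API
touched above; the histograms and grid below are made up for illustration:

    import numpy as np
    import ot

    n = 100
    x = np.arange(n, dtype=np.float64)
    a1 = np.exp(-(x - 20) ** 2 / 50.)   # two Gaussian-ish histograms
    a2 = np.exp(-(x - 60) ** 2 / 128.)
    A = np.vstack((a1 / a1.sum(), a2 / a2.sum())).T   # (dim, n_hists)
    M = ot.utils.dist(x.reshape(-1, 1), x.reshape(-1, 1))
    M /= M.max()
    # 'sinkhorn_log' is the log-domain solver added by this patch; it is
    # the safer choice when reg is small.
    bary = ot.bregman.barycenter(A, M, reg=1e-2, method='sinkhorn_log')
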
def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
- r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ stopThr=1e-4, verbose=False, log=False, warn=True):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`
The function solves the following optimization problem:
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
+ (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in :ref:`[3]<references-barycenter-sinkhorn>`.
Parameters
----------
- A : ndarray, shape (dim, n_hists)
- n_hists training distributions a_i of size dim
- M : ndarray, shape (dim, dim)
+ A : array-like, shape (dim, n_hists)
+ `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
+ M : array-like, shape (dim, dim)
loss matrix for OT
reg : float
Regularization term > 0
- weights : ndarray, shape (n_hists,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ weights : array-like, shape (n_hists,)
+ Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Wasserstein barycenter
log : dict
log dictionary returned only if log==True in parameters
+ .. _references-barycenter-sinkhorn:
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+ Iterative Bregman projections for regularized transportation problems.
+ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
"""
+ A, M = list_to_array(A, M)
+
+ nx = get_backend(A, M)
+
if weights is None:
- weights = np.ones(A.shape[1]) / A.shape[1]
+ weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1]
else:
assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
- # M = M/np.median(M) # suggested by G. Peyre
- K = np.exp(-M / reg)
+ K = nx.exp(-M / reg)
- cpt = 0
err = 1
- UKv = np.dot(K, np.divide(A.T, np.sum(K, axis=0)).T)
+ UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T)
+
u = (geometricMean(UKv) / UKv.T).T
- while (err > stopThr and cpt < numItermax):
- cpt = cpt + 1
- UKv = u * np.dot(K, np.divide(A, np.dot(K, u)))
+ for ii in range(numItermax):
+
+ UKv = u * nx.dot(K, A / nx.dot(K, u))
u = (u.T * geometricBar(weights, UKv)).T / UKv
- if cpt % 10 == 1:
- err = np.sum(np.std(UKv, axis=1))
+ if ii % 10 == 1:
+ err = nx.sum(nx.std(UKv, axis=1))
# log and verbose print
if log:
log['err'].append(err)
+ if err < stopThr:
+ break
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
+ print('{:5d}|{:8e}|'.format(ii, err))
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
- log['niter'] = cpt
+ log['niter'] = ii
return geometricBar(weights, UKv), log
else:
return geometricBar(weights, UKv)
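
For readers of the patch, a plain-NumPy transcription of the fixed point
iterated above (a sketch mirroring the backend calls, not a replacement for
them; the weighted/plain geometric means stand in for `geometricBar` and
`geometricMean` used in this module):

    import numpy as np

    def barycenter_sinkhorn_np(A, M, reg, weights, n_iter=1000):
        K = np.exp(-M / reg)
        UKv = K.dot(A / K.sum(axis=0)[:, None])
        u = np.exp(np.mean(np.log(UKv), axis=1))[:, None] / UKv
        for _ in range(n_iter):
            UKv = u * K.dot(A / K.dot(u))
            bar = np.exp(np.log(UKv) @ weights)   # weighted geometric mean
            u = u * bar[:, None] / UKv
        return bar
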
+def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False, warn=True):
+ r"""Compute the entropic wasserstein barycenter in log-domain
+ """
+
+ A, M = list_to_array(A, M)
+ dim, n_hists = A.shape
+
+ nx = get_backend(A, M)
+
+ if nx.__name__ == "jax":
+ raise NotImplementedError("Log-domain functions are not yet implemented"
+ " for Jax. Use numpy or torch arrays instead.")
+
+ if weights is None:
+ weights = nx.ones(n_hists, type_as=A) / n_hists
+ else:
+ assert (len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ M = - M / reg
+ logA = nx.log(A + 1e-15)
+ log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
+ err = 1
+ for ii in range(numItermax):
+ log_bar = nx.zeros(dim, type_as=A)
+ for k in range(n_hists):
+ f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1)
+ log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0)
+ log_bar = log_bar + weights[k] * log_KU[:, k]
+
+ if ii % 10 == 1:
+ err = nx.exp(G + log_KU).std(axis=1).sum()
+
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if err < stopThr:
+ break
+ if verbose:
+ if ii % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+
+ G = log_bar[:, None] - log_KU
+
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ if log:
+ log['niter'] = ii
+ return nx.exp(log_bar), log
+ else:
+ return nx.exp(log_bar)
+
+
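
The two `logsumexp` calls in the inner loop are the log-domain counterparts
of the kernel products :math:`K^T u` and :math:`K v`. A standalone sketch of
one column update, using SciPy (assumed available, as elsewhere in this
module):

    import numpy as np
    from scipy.special import logsumexp

    def log_ku_column(M_reg, logA_k, G_k):
        # M_reg = -M / reg; G_k is the current log-scaling of histogram k
        f = logA_k - logsumexp(M_reg + G_k[None, :], axis=1)
        return logsumexp(M_reg + f[:, None], axis=0)
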
def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
- r"""Compute the entropic regularized wasserstein barycenter of distributions A
- with stabilization.
+ stopThr=1e-4, verbose=False, log=False, warn=True):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization.
The function solves the following optimization problem:
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
+ distance (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling
+ algorithm as proposed in :ref:`[3] <references-barycenter-stabilized>`
Parameters
----------
- A : ndarray, shape (dim, n_hists)
- n_hists training distributions a_i of size dim
- M : ndarray, shape (dim, dim)
+ A : array-like, shape (dim, n_hists)
+ `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
+ M : array-like, shape (dim, dim)
loss matrix for OT
reg : float
Regularization term > 0
tau : float
- thershold for max value in u or v for log scaling
- weights : ndarray, shape (n_hists,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
+ for log scaling
+ weights : array-like, shape (n_hists,)
+ Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Wasserstein barycenter
log : dict
log dictionary returned only if log==True in parameters
+ .. _references-barycenter-stabilized:
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+ Iterative Bregman projections for regularized transportation problems.
+ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
"""
+ A, M = list_to_array(A, M)
+
+ nx = get_backend(A, M)
+
dim, n_hists = A.shape
if weights is None:
- weights = np.ones(n_hists) / n_hists
+ weights = nx.ones((n_hists,), type_as=M) / n_hists
else:
assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
- u = np.ones((dim, n_hists)) / dim
- v = np.ones((dim, n_hists)) / dim
+ u = nx.ones((dim, n_hists), type_as=M) / dim
+ v = nx.ones((dim, n_hists), type_as=M) / dim
- # print(reg)
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
- cpt = 0
err = 1.
- alpha = np.zeros(dim)
- beta = np.zeros(dim)
- q = np.ones(dim) / dim
- while (err > stopThr and cpt < numItermax):
+ alpha = nx.zeros((dim,), type_as=M)
+ beta = nx.zeros((dim,), type_as=M)
+ q = nx.ones((dim,), type_as=M) / dim
+ for ii in range(numItermax):
qprev = q
- Kv = K.dot(v)
- u = A / (Kv + 1e-16)
- Ktu = K.T.dot(u)
+ Kv = nx.dot(K, v)
+ u = A / Kv
+ Ktu = nx.dot(K.T, u)
q = geometricBar(weights, Ktu)
Q = q[:, None]
- v = Q / (Ktu + 1e-16)
+ v = Q / Ktu
absorbing = False
- if (u > tau).any() or (v > tau).any():
+ if nx.any(u > tau) or nx.any(v > tau):
absorbing = True
- alpha = alpha + reg * np.log(np.max(u, 1))
- beta = beta + reg * np.log(np.max(v, 1))
- K = np.exp((alpha[:, None] + beta[None, :] -
- M) / reg)
- v = np.ones_like(v)
- Kv = K.dot(v)
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ alpha += reg * nx.log(nx.max(u, 1))
+ beta += reg * nx.log(nx.max(v, 1))
+ K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
+ v = nx.ones(tuple(v.shape), type_as=v)
+ Kv = nx.dot(K, v)
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- warnings.warn('Numerical errors at iteration %s' % cpt)
+ warnings.warn('Numerical errors at iteration %s' % ii)
q = qprev
break
- if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ if (ii % 10 == 0 and not absorbing) or ii == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = abs(u * Kv - A).max()
+ err = nx.max(nx.abs(u * Kv - A))
if log:
log['err'].append(err)
+ if err < stopThr:
+ break
if verbose:
- if cpt % 50 == 0:
+ if ii % 50 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
+ print('{:5d}|{:8e}|'.format(ii, err))
- cpt += 1
- if err > stopThr:
- warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
- "Try a larger entropy `reg`" +
- "Or a larger absorption threshold `tau`.")
+ else:
+ if warn:
+ warnings.warn("Stabilized Sinkhorn did not converge." +
+ "Try a larger entropy `reg`" +
+ "Or a larger absorption threshold `tau`.")
if log:
- log['niter'] = cpt
+ log['niter'] = ii
- log['logu'] = np.log(u + 1e-16)
- log['logv'] = np.log(v + 1e-16)
+ log['logu'] = nx.log(u + 1e-16)
+ log['logv'] = nx.log(v + 1e-16)
return q, log
@@ -1272,157 +1740,717 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
return q
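
The absorption step above is the standard log-stabilization trick: when a
scaling vector exceeds `tau`, its log is absorbed into the dual potentials
and the kernel is rebuilt. A NumPy sketch of that step alone:

    import numpy as np

    def absorb(u, v, alpha, beta, M, reg):
        # move the large scalings into the duals, then reset v
        alpha = alpha + reg * np.log(np.max(u, axis=1))
        beta = beta + reg * np.log(np.max(v, axis=1))
        K = np.exp((alpha[:, None] + beta[None, :] - M) / reg)
        return K, np.ones_like(v), alpha, beta
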
-def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
- stopThr=1e-9, stabThr=1e-30, verbose=False,
- log=False):
- r"""Compute the entropic regularized wasserstein barycenter of distributions A
- where A is a collection of 2D images.
+def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
+ stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs):
+ r"""Compute the debiased Sinkhorn barycenter of distributions A
The function solves the following optimization problem:
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}`
- - reg is the regularization strength scalar value
+ - :math:`S_{reg}(\cdot,\cdot)` is the debiased Sinkhorn divergence
+ (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_
+ The algorithm used for solving the problem is the debiased Sinkhorn
+ algorithm as proposed in :ref:`[37] <references-barycenter-debiased>`
Parameters
----------
- A : ndarray, shape (n_hists, width, height)
- n distributions (2D images) of size width x height
+ A : array-like, shape (dim, n_hists)
+ `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
+ M : array-like, shape (dim, dim)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ method : str (optional)
+ method used for the solver either 'sinkhorn' or 'sinkhorn_log'
+ weights : array-like, shape (n_hists,)
+ Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
+
+
+ Returns
+ -------
+ a : (dim,) array-like
+ Wasserstein barycenter
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ .. _references-barycenter-debiased:
+ References
+ ----------
+ .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International
+ Conference on Machine Learning, PMLR 119:4692-4701, 2020
+ """
+
+ if method.lower() == 'sinkhorn':
+ return _barycenter_debiased(A, M, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return _barycenter_debiased_log(A, M, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+
+
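
A usage sketch (illustrative data, not from the patch): the debiased
barycenter takes the same inputs as `ot.bregman.barycenter` and typically
yields a sharper solution for the same `reg`:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    A = rng.rand(50, 3)
    A /= A.sum(axis=0)                        # three histograms on 50 bins
    x = np.linspace(0, 1, 50).reshape(-1, 1)
    M = ot.utils.dist(x, x)
    bar = ot.bregman.barycenter_debiased(A, M, reg=1e-2,
                                         method='sinkhorn_log')
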
+def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False, warn=True):
+ r"""Compute the debiased sinkhorn barycenter of distributions A.
+ """
+
+ A, M = list_to_array(A, M)
+
+ nx = get_backend(A, M)
+
+ if weights is None:
+ weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1]
+ else:
+ assert (len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ K = nx.exp(-M / reg)
+
+ err = 1
+
+ UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T)
+
+ u = (geometricMean(UKv) / UKv.T).T
+ c = nx.ones(A.shape[0], type_as=A)
+ bar = nx.ones(A.shape[0], type_as=A)
+
+ for ii in range(numItermax):
+ bold = bar
+ UKv = nx.dot(K, A / nx.dot(K, u))
+ bar = c * geometricBar(weights, UKv)
+ u = bar[:, None] / UKv
+ c = (c * bar / nx.dot(K, c)) ** 0.5
+
+ if ii % 10 == 9:
+ err = abs(bar - bold).max() / max(bar.max(), 1.)
+
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ # debiased Sinkhorn does not converge monotonically
+ # guarantee a few iterations are done before stopping
+ if err < stopThr and ii > 20:
+ break
+ if verbose:
+ if ii % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ if log:
+ log['niter'] = ii
+ return bar, log
+ else:
+ return bar
+
+
+def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False,
+ warn=True):
+ r"""Compute the debiased sinkhorn barycenter in log domain.
+ """
+
+ A, M = list_to_array(A, M)
+ dim, n_hists = A.shape
+
+ nx = get_backend(A, M)
+ if nx.__name__ == "jax":
+ raise NotImplementedError("Log-domain functions are not yet implemented"
+ " for Jax. Use numpy or torch arrays instead.")
+
+ if weights is None:
+ weights = nx.ones(n_hists, type_as=A) / n_hists
+ else:
+ assert (len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ M = - M / reg
+ logA = nx.log(A + 1e-15)
+ log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
+ c = nx.zeros(dim, type_as=A)
+ err = 1
+ for ii in range(numItermax):
+ log_bar = nx.zeros(dim, type_as=A)
+ for k in range(n_hists):
+ f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1)
+ log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0)
+ log_bar += weights[k] * log_KU[:, k]
+ log_bar += c
+ if ii % 10 == 1:
+ err = nx.exp(G + log_KU).std(axis=1).sum()
+
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if err < stopThr and ii > 20:
+ break
+ if verbose:
+ if ii % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+
+ G = log_bar[:, None] - log_KU
+ for _ in range(10):
+ c = 0.5 * (c + log_bar - nx.logsumexp(M + c[:, None], axis=0))
+
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ if log:
+ log['niter'] = ii
+ return nx.exp(log_bar), log
+ else:
+ return nx.exp(log_bar)
+
+
+def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000,
+ stopThr=1e-4, verbose=False, log=False,
+ warn=True, **kwargs):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`
+ where :math:`\mathbf{A}` is a collection of 2D images.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
+ distance (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two dimensions
+ of matrix :math:`\mathbf{A}`
+ - `reg` is the regularization strength scalar value
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm
+ as proposed in :ref:`[21] <references-convolutional-barycenter-2d>`
+
+ Parameters
+ ----------
+ A : array-like, shape (n_hists, width, height)
+ `n` distributions (2D images) of size `width` x `height`
reg : float
Regularization term >0
- weights : ndarray, shape (n_hists,)
+ weights : array-like, shape (n_hists,)
Weights of each image on the simplex (barycentric coordinates)
+ method : string, optional
+ method used for the solver either 'sinkhorn' or 'sinkhorn_log'
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (> 0)
stabThr : float, optional
Stabilization threshold to avoid numerical precision issue
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- a : ndarray, shape (width, height)
+ a : array-like, shape (width, height)
2D Wasserstein barycenter
log : dict
log dictionary returned only if log==True in parameters
+
+ .. _references-convolutional-barycenter-2d:
References
----------
- .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015).
- Convolutional wasserstein distances: Efficient optimal transportation on geometric domains
- ACM Transactions on Graphics (TOG), 34(4), 66
+ .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher,
+ A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances:
+ Efficient optimal transportation on geometric domains. ACM Transactions
+ on Graphics (TOG), 34(4), 66
+
+ .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th
+ International Conference on Machine Learning, PMLR 119:4692-4701, 2020
+ """
+
+ if method.lower() == 'sinkhorn':
+ return _convolutional_barycenter2d(A, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return _convolutional_barycenter2d_log(A, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
+ **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
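
A usage sketch with made-up images (not part of the patch); `A` is stacked
along the first axis as documented above:

    import numpy as np
    import ot

    A = np.zeros((2, 28, 28))
    A[0, 5:10, 5:10] = 1.
    A[1, 18:23, 18:23] = 1.
    A /= A.sum(axis=(1, 2), keepdims=True)    # each image sums to 1
    bar = ot.bregman.convolutional_barycenter2d(A, reg=0.004,
                                                method='sinkhorn_log')
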
+def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-9, stabThr=1e-30, verbose=False,
+ log=False, warn=True):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ where A is a collection of 2D images.
"""
+ A = list_to_array(A)
+
+ nx = get_backend(A)
+
if weights is None:
- weights = np.ones(A.shape[0]) / A.shape[0]
+ weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0]
else:
assert (len(weights) == A.shape[0])
if log:
log = {'err': []}
- b = np.zeros_like(A[0, :, :])
- U = np.ones_like(A)
- KV = np.ones_like(A)
-
- cpt = 0
+ bar = nx.ones(A.shape[1:], type_as=A)
+ bar /= bar.sum()
+ U = nx.ones(A.shape, type_as=A)
+ V = nx.ones(A.shape, type_as=A)
err = 1
# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
- t = np.linspace(0, 1, A.shape[1])
- [Y, X] = np.meshgrid(t, t)
- xi1 = np.exp(-(X - Y) ** 2 / reg)
+ t = nx.linspace(0, 1, A.shape[1])
+ [Y, X] = nx.meshgrid(t, t)
+ K1 = nx.exp(-(X - Y) ** 2 / reg)
+
+ t = nx.linspace(0, 1, A.shape[2])
+ [Y, X] = nx.meshgrid(t, t)
+ K2 = nx.exp(-(X - Y) ** 2 / reg)
+
+ def convol_imgs(imgs):
+ kx = nx.einsum("...ij,kjl->kil", K1, imgs)
+ kxy = nx.einsum("...ij,klj->kli", K2, kx)
+ return kxy
+
+ KU = convol_imgs(U)
+ for ii in range(numItermax):
+ V = bar[None] / KU
+ KV = convol_imgs(V)
+ U = A / KV
+ KU = convol_imgs(U)
+ bar = nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+ if ii % 10 == 9:
+ err = (V * KU).std(axis=0).sum()
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if ii % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr:
+ break
+
+ else:
+ if warn:
+ warnings.warn("Convolutional Sinkhorn did not converge. "
+ "Try a larger number of iterations `numItermax` "
+ "or a larger entropy `reg`.")
+ if log:
+ log['niter'] = ii
+ log['U'] = U
+ return bar, log
+ else:
+ return bar
+
+
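
The `convol_imgs` helper above applies a separable Gibbs kernel: blurring
rows with `K1` and then columns with `K2` equals multiplying by the full
(width*height)^2 kernel without ever forming it. A self-contained check:

    import numpy as np

    t = np.linspace(0, 1, 16)
    Y, X = np.meshgrid(t, t)
    K1 = K2 = np.exp(-(X - Y) ** 2 / 0.01)    # square images, same grid
    imgs = np.random.rand(3, 16, 16)
    kx = np.einsum("ij,kjl->kil", K1, imgs)   # blur along the first axis
    kxy = np.einsum("ij,klj->kli", K2, kx)    # blur along the second axis
    ref = np.stack([K1 @ im @ K2.T for im in imgs])
    assert np.allclose(kxy, ref)
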
+def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-4, stabThr=1e-30, verbose=False,
+ log=False, warn=True):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ where A is a collection of 2D images in log-domain.
+ """
+
+ A = list_to_array(A)
+
+ nx = get_backend(A)
+ if nx.__name__ == "jax":
+ raise NotImplementedError("Log-domain functions are not yet implemented"
+ " for Jax. Use numpy or torch arrays instead.")
+
+ n_hists, width, height = A.shape
+
+ if weights is None:
+ weights = nx.ones((n_hists,), type_as=A) / n_hists
+ else:
+ assert (len(weights) == n_hists)
+
+ if log:
+ log = {'err': []}
+
+ err = 1
+ # build the convolution operator
+ # this is equivalent to blurring on horizontal then vertical directions
+ t = nx.linspace(0, 1, width)
+ [Y, X] = nx.meshgrid(t, t)
+ M1 = - (X - Y) ** 2 / reg
+
+ t = nx.linspace(0, 1, height)
+ [Y, X] = nx.meshgrid(t, t)
+ M2 = - (X - Y) ** 2 / reg
+
+ def convol_img(log_img):
+ log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
+ log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
+ return log_img
+
+ logA = nx.log(A + stabThr)
+ log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A)
+ err = 1
+ for ii in range(numItermax):
+ log_bar = nx.zeros((width, height), type_as=A)
+ for k in range(n_hists):
+ f = logA[k] - convol_img(G[k])
+ log_KU[k] = convol_img(f)
+ log_bar = log_bar + weights[k] * log_KU[k]
+
+ if ii % 10 == 9:
+ err = nx.exp(G + log_KU).std(axis=0).sum()
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if ii % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr:
+ break
+ G = log_bar[None, :, :] - log_KU
+
+ else:
+ if warn:
+ warnings.warn("Convolutional Sinkhorn did not converge. "
+ "Try a larger number of iterations `numItermax` "
+ "or a larger entropy `reg`.")
+ if log:
+ log['niter'] = ii
+ return nx.exp(log_bar), log
+ else:
+ return nx.exp(log_bar)
+
+
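
`convol_img` is the same separable blur carried out in log space, so very
small pixel masses survive a small `reg`. A SciPy sketch of the helper:

    import numpy as np
    from scipy.special import logsumexp

    def convol_img_log(log_img, M1, M2):
        # M1, M2 are -(X - Y)**2 / reg on the width and height grids
        out = logsumexp(M1[:, :, None] + log_img[None], axis=1)
        return logsumexp(M2[:, :, None] + out.T[None], axis=1).T
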
+def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn",
+ numItermax=10000, stopThr=1e-3,
+ verbose=False, log=False, warn=True,
+ **kwargs):
+ r"""Compute the debiased sinkhorn barycenter of distributions :math:`\mathbf{A}`
+ where :math:`\mathbf{A}` is a collection of 2D images.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein
+ distance (see :py:func:`ot.bregman.barycenter_debiased`)
+ - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two
+ dimensions of matrix :math:`\mathbf{A}`
+ - `reg` is the regularization strength scalar value
+
+ The algorithm used for solving the problem is the debiased Sinkhorn scaling
+ algorithm as proposed in :ref:`[37] <references-convolutional-barycenter2d-debiased>`
+
+ Parameters
+ ----------
+ A : array-like, shape (n_hists, width, height)
+ `n` distributions (2D images) of size `width` x `height`
+ reg : float
+ Regularization term >0
+ weights : array-like, shape (n_hists,)
+ Weights of each image on the simplex (barycentric coordinates)
+ method : string, optional
+ method used for the solver either 'sinkhorn' or 'sinkhorn_log'
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ stabThr : float, optional
+ Stabilization threshold to avoid numerical precision issue
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
+
+
+ Returns
+ -------
+ a : array-like, shape (width, height)
+ 2D Wasserstein barycenter
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ .. _references-convolutional-barycenter2d-debiased:
+ References
+ ----------
+
+ .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International
+ Conference on Machine Learning, PMLR 119:4692-4701, 2020
+ """
- t = np.linspace(0, 1, A.shape[2])
- [Y, X] = np.meshgrid(t, t)
- xi2 = np.exp(-(X - Y) ** 2 / reg)
+ if method.lower() == 'sinkhorn':
+ return _convolutional_barycenter2d_debiased(A, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return _convolutional_barycenter2d_debiased_log(A, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
+ **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
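
A comparison sketch (illustrative data): the debiased variant reuses the
biased solver's inputs and mainly differs in the extra correction term `c`
updated in the helpers below:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    A = rng.rand(4, 32, 32)
    A /= A.sum(axis=(1, 2), keepdims=True)
    bar = ot.bregman.convolutional_barycenter2d(A, reg=0.003)
    bar_db = ot.bregman.convolutional_barycenter2d_debiased(A, reg=0.003)
    # bar_db is typically sharper (less entropic blur) for the same reg
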
- def K(x):
- return np.dot(np.dot(xi1, x), xi2)
- while (err > stopThr and cpt < numItermax):
+def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-3, stabThr=1e-15, verbose=False,
+ log=False, warn=True):
+ r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions.
+ """
- bold = b
- cpt = cpt + 1
+ A = list_to_array(A)
+ n_hists, width, height = A.shape
- b = np.zeros_like(A[0, :, :])
- for r in range(A.shape[0]):
- KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :])))
- b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :]))
- b = np.exp(b)
- for r in range(A.shape[0]):
- U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :])
+ nx = get_backend(A)
- if cpt % 10 == 1:
- err = np.sum(np.abs(bold - b))
+ if weights is None:
+ weights = nx.ones((n_hists,), type_as=A) / n_hists
+ else:
+ assert (len(weights) == n_hists)
+
+ if log:
+ log = {'err': []}
+
+ bar = nx.ones((width, height), type_as=A)
+ bar /= width * height
+ U = nx.ones(A.shape, type_as=A)
+ V = nx.ones(A.shape, type_as=A)
+ c = nx.ones(A.shape[1:], type_as=A)
+ err = 1
+
+ # build the convolution operator
+ # this is equivalent to blurring on horizontal then vertical directions
+ t = nx.linspace(0, 1, width)
+ [Y, X] = nx.meshgrid(t, t)
+ K1 = nx.exp(-(X - Y) ** 2 / reg)
+
+ t = nx.linspace(0, 1, height)
+ [Y, X] = nx.meshgrid(t, t)
+ K2 = nx.exp(-(X - Y) ** 2 / reg)
+
+ def convol_imgs(imgs):
+ kx = nx.einsum("...ij,kjl->kil", K1, imgs)
+ kxy = nx.einsum("...ij,klj->kli", K2, kx)
+ return kxy
+
+ KU = convol_imgs(U)
+ for ii in range(numItermax):
+ V = bar[None] / KU
+ KV = convol_imgs(V)
+ U = A / KV
+ KU = convol_imgs(U)
+ bar = c * nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+
+ for _ in range(10):
+ c = (c * bar / convol_imgs(c[None]).squeeze()) ** 0.5
+
+ if ii % 10 == 9:
+ err = (V * KU).std(axis=0).sum()
# log and verbose print
if log:
log['err'].append(err)
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
+ print('{:5d}|{:8e}|'.format(ii, err))
+ # debiased Sinkhorn does not converge monotonically
+ # guarantee a few iterations are done before stopping
+ if err < stopThr and ii > 20:
+ break
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
- log['niter'] = cpt
+ log['niter'] = ii
log['U'] = U
- return b, log
+ return bar, log
+ else:
+ return bar
+
+
+def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-3, stabThr=1e-30, verbose=False,
+ log=False, warn=True):
+ r"""Compute the debiased barycenter of 2D images in log-domain.
+ """
+
+ A = list_to_array(A)
+ n_hists, width, height = A.shape
+ nx = get_backend(A)
+ if nx.__name__ == "jax":
+ raise NotImplementedError("Log-domain functions are not yet implemented"
+ " for Jax. Use numpy or torch arrays instead.")
+ if weights is None:
+ weights = nx.ones((n_hists,), type_as=A) / n_hists
+ else:
+ assert (len(weights) == A.shape[0])
+
+ if log:
+ log = {'err': []}
+
+ err = 1
+ # build the convolution operator
+ # this is equivalent to blurring on horizontal then vertical directions
+ t = nx.linspace(0, 1, width)
+ [Y, X] = nx.meshgrid(t, t)
+ M1 = - (X - Y) ** 2 / reg
+
+ t = nx.linspace(0, 1, height)
+ [Y, X] = nx.meshgrid(t, t)
+ M2 = - (X - Y) ** 2 / reg
+
+ def convol_img(log_img):
+ log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
+ log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
+ return log_img
+
+ logA = nx.log(A + stabThr)
+ log_bar, c = nx.zeros((2, width, height), type_as=A)
+ log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A)
+ err = 1
+ for ii in range(numItermax):
+ log_bar = nx.zeros((width, height), type_as=A)
+ for k in range(n_hists):
+ f = logA[k] - convol_img(G[k])
+ log_KU[k] = convol_img(f)
+ log_bar = log_bar + weights[k] * log_KU[k]
+ log_bar += c
+ for _ in range(10):
+ c = 0.5 * (c + log_bar - convol_img(c))
+
+ if ii % 10 == 9:
+ err = nx.exp(G + log_KU).std(axis=0).sum()
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if ii % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr and ii > 20:
+ break
+ G = log_bar[None, :, :] - log_KU
+
+ else:
+ if warn:
+ warnings.warn("Convolutional Sinkhorn did not converge. "
+ "Try a larger number of iterations `numItermax` "
+ "or a larger entropy `reg`.")
+ if log:
+ log['niter'] = ii
+ return nx.exp(log_bar), log
else:
- return b
+ return nx.exp(log_bar)
def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
- stopThr=1e-3, verbose=False, log=False):
+ stopThr=1e-3, verbose=False, log=False, warn=True):
r"""
Compute the unmixing of an observation with a given dictionary using Wasserstein distance
The function solves the following optimization problem:
.. math::
- \mathbf{h} = arg\min_\mathbf{h} (1- \\alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\\alpha W_{M0,reg0}(\mathbf{h}_0,\mathbf{h})
+
+ \mathbf{h} = \mathop{\arg \min}_\mathbf{h} \quad
+ (1 - \alpha) W_{\mathbf{M}, \mathrm{reg}}(\mathbf{a}, \mathbf{Dh}) +
+ \alpha W_{\mathbf{M_0}, \mathrm{reg}_0}(\mathbf{h}_0, \mathbf{h})
where :
- - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn)
- - :math: `\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)`
+ - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
+ with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`,
+ its expected shape is `(dim_a, n_atoms)`
- :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms`
- :math:`\mathbf{a}` is an observed distribution of dimension `dim_a`
- - :math:`\mathbf{h}_0` is a prior on `h` of dimension `dim_prior`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (dim_a, dim_a) for OT data fitting
- - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix (dim_prior, n_atoms) regularization
- - :math:`\\alpha`weight data fitting and regularization
+ - :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the
+ cost matrix (`dim_a`, `dim_a`) for OT data fitting
+ - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization
+ term and the cost matrix (`dim_prior`, `n_atoms`) for the regularization
+ - :math:`\alpha` weights the data fitting and regularization terms
- The optimization problem is solved suing the algorithm described in [4]
+ The optimization problem is solved following the algorithm described
+ in :ref:`[4] <references-unmix>`
Parameters
----------
- a : ndarray, shape (dim_a)
+ a : array-like, shape (dim_a)
observed distribution (histogram, sums to 1)
- D : ndarray, shape (dim_a, n_atoms)
+ D : array-like, shape (dim_a, n_atoms)
dictionary matrix
- M : ndarray, shape (dim_a, dim_a)
+ M : array-like, shape (dim_a, dim_a)
loss matrix
- M0 : ndarray, shape (n_atoms, dim_prior)
+ M0 : array-like, shape (n_atoms, dim_prior)
loss matrix
- h0 : ndarray, shape (n_atoms,)
+ h0 : array-like, shape (n_atoms,)
prior on the estimated unmixing h
reg : float
Regularization term >0 (Wasserstein data fitting)
@@ -1433,105 +2461,125 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
-
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- h : ndarray, shape (n_atoms,)
+ h : array-like, shape (n_atoms,)
estimated unmixing (proportions over the dictionary atoms)
log : dict
log dictionary returned only if log==True in parameters
+
+ .. _references-unmix:
References
----------
- .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Whorkshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016.
-
+ .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti,
+ Supervised planetary unmixing with optimal transport, Workshop
+ on Hyperspectral Image and Signal Processing :
+ Evolution in Remote Sensing (WHISPERS), 2016.
"""
+ a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0)
+
+ nx = get_backend(a, D, M, M0, h0)
+
# M = M/np.median(M)
- K = np.exp(-M / reg)
+ K = nx.exp(-M / reg)
# M0 = M0/np.median(M0)
- K0 = np.exp(-M0 / reg0)
+ K0 = nx.exp(-M0 / reg0)
old = h0
err = 1
- cpt = 0
# log = {'niter':0, 'all_err':[]}
if log:
log = {'err': []}
- while (err > stopThr and cpt < numItermax):
+ for ii in range(numItermax):
K = projC(K, a)
K0 = projC(K0, h0)
- new = np.sum(K0, axis=1)
+ new = nx.sum(K0, axis=1)
# we recombine the current selection from dictionnary
- inv_new = np.dot(D, new)
- other = np.sum(K, axis=1)
+ inv_new = nx.dot(D, new)
+ other = nx.sum(K, axis=1)
# geometric interpolation
- delta = np.exp(alpha * np.log(other) + (1 - alpha) * np.log(inv_new))
+ delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new))
K = projR(K, delta)
- K0 = np.dot(np.diag(np.dot(D.T, delta / inv_new)), K0)
+ K0 = nx.dot(nx.diag(nx.dot(D.T, delta / inv_new)), K0)
- err = np.linalg.norm(np.sum(K0, axis=1) - old)
+ err = nx.norm(nx.sum(K0, axis=1) - old)
old = new
if log:
log['err'].append(err)
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- cpt = cpt + 1
-
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr:
+ break
+ else:
+ if warn:
+ warnings.warn("Unmixing algorithm did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
- log['niter'] = cpt
- return np.sum(K0, axis=1), log
+ log['niter'] = ii
+ return nx.sum(K0, axis=1), log
else:
- return np.sum(K0, axis=1)
+ return nx.sum(K0, axis=1)
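
A usage sketch for `unmix` on synthetic data (all arrays below are made up
for illustration; shapes follow the docstring):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    dim_a, n_atoms = 30, 3
    D = rng.rand(dim_a, n_atoms)
    D /= D.sum(axis=0)                        # dictionary of histograms
    a = D @ np.array([0.2, 0.5, 0.3])         # observation mixed from atoms
    x = np.arange(dim_a, dtype=np.float64).reshape(-1, 1)
    M = ot.utils.dist(x, x); M /= M.max()
    y = np.arange(n_atoms, dtype=np.float64).reshape(-1, 1)
    M0 = ot.utils.dist(y, y); M0 /= M0.max()
    h0 = np.ones(n_atoms) / n_atoms           # uniform prior
    h = ot.bregman.unmix(a, D, M, M0, h0, reg=0.1, reg0=0.1, alpha=0.5)
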
def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
- stopThr=1e-6, verbose=False, log=False, **kwargs):
- r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27]
+ stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs):
+ r'''Joint OT and proportion estimation for multi-source target shift as
+ proposed in :ref:`[27] <references-jcpot-barycenter>`
The function solves the following optimization problem:
.. math::
- \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k
+ \mathbf{h} = \mathop{\arg \min}_{\mathbf{h}} \quad \sum_{k=1}^{K} \lambda_k
W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a})
s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h}
where :
- - :math:`\lambda_k` is the weight of k-th source domain
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes
- - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C
+ - :math:`\lambda_k` is the weight of `k`-th source domain
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
+ (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain
+ defined as in [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape
+ is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source
+ domain and `C` is the number of classes
+ - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C`
- :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n`
- - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)`
+ - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in
+ [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape is :math:`(n_k, C)`
- The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain.
+ The problem consists in solving a Wasserstein barycenter problem to estimate
+ the proportions :math:`\mathbf{h}` in the target domain.
The algorithm used for solving the problem is the Iterative Bregman projections algorithm
- with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform target distribution.
+ with two sets of marginal constraints related to the unknown vector
+ :math:`\mathbf{h}` and uniform target distribution.
Parameters
----------
- Xs : list of K np.ndarray(nsk,d)
+ Xs : list of K array-like(nsk,d)
features of all source domains' samples
- Ys : list of K np.ndarray(nsk,)
+ Ys : list of K array-like(nsk,)
labels of all source domains' samples
- Xt : np.ndarray (nt,d)
+ Xt : array-like (nt,d)
samples in the target domain
reg : float
Regularization term > 0
@@ -1541,28 +2589,37 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
Max number of iterations
stopThr : float, optional
Stop threshold on relative change in the barycenter (>0)
- log : bool, optional
- record log if True
verbose : bool, optional (default=False)
Controls the verbosity of the optimization algorithm
+ log : bool, optional
+ record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- h : (C,) ndarray
+ h : (C,) array-like
proportion estimation in the target domain
log : dict
log dictionary returned only if log==True in parameters
+ .. _references-jcpot-barycenter:
References
----------
.. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
- "Optimal transport for multi-source domain adaptation under target shift",
- International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
-
+ "Optimal transport for multi-source domain adaptation under target shift",
+ International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
'''
- nbclasses = len(np.unique(Ys[0]))
+
+ Xs = list_to_array(*Xs)
+ Ys = list_to_array(*Ys)
+ Xt = list_to_array(Xt)
+
+ nx = get_backend(*Xs, *Ys, Xt)
+
+ nbclasses = len(nx.unique(Ys[0]))
nbdomains = len(Xs)
# log dictionary
@@ -1579,19 +2636,19 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
dom = {}
nsk = Xs[d].shape[0] # get number of elements for this domain
dom['nbelem'] = nsk
- classes = np.unique(Ys[d]) # get number of classes for this domain
+ classes = nx.unique(Ys[d]) # get number of classes for this domain
# format classes to start from 0 for convenience
- if np.min(classes) != 0:
- Ys[d] = Ys[d] - np.min(classes)
- classes = np.unique(Ys[d])
+ if nx.min(classes) != 0:
+ Ys[d] -= nx.min(classes)
+ classes = nx.unique(Ys[d])
# build the corresponding D_1 and D_2 matrices
- Dtmp1 = np.zeros((nbclasses, nsk))
- Dtmp2 = np.zeros((nbclasses, nsk))
+ Dtmp1 = nx.zeros((nbclasses, nsk), type_as=Xs[0])
+ Dtmp2 = nx.zeros((nbclasses, nsk), type_as=Xs[0])
for c in classes:
- nbelemperclass = np.sum(Ys[d] == c)
+ nbelemperclass = nx.sum(Ys[d] == c)
if nbelemperclass != 0:
Dtmp1[int(c), Ys[d] == c] = 1.
Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
@@ -1602,51 +2659,54 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
Mtmp = dist(Xs[d], Xt, metric=metric)
M.append(Mtmp)
- Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype)
- np.divide(Mtmp, -reg, out=Ktmp)
- np.exp(Ktmp, out=Ktmp)
+ Ktmp = nx.exp(-Mtmp / reg)
K.append(Ktmp)
# uniform target distribution
- a = unif(np.shape(Xt)[0])
+ a = nx.from_numpy(unif(Xt.shape[0]), type_as=Xs[0])
- cpt = 0 # iterations count
err = 1
- old_bary = np.ones((nbclasses))
+ old_bary = nx.ones((nbclasses,), type_as=Xs[0])
- while (err > stopThr and cpt < numItermax):
+ for ii in range(numItermax):
- bary = np.zeros((nbclasses))
+ bary = nx.zeros((nbclasses,), type_as=Xs[0])
# update coupling matrices for marginal constraints w.r.t. uniform target distribution
for d in range(nbdomains):
K[d] = projC(K[d], a)
- other = np.sum(K[d], axis=1)
- bary = bary + np.log(np.dot(D1[d], other)) / nbdomains
+ other = nx.sum(K[d], axis=1)
+ bary += nx.log(nx.dot(D1[d], other)) / nbdomains
- bary = np.exp(bary)
+ bary = nx.exp(bary)
# update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27]
for d in range(nbdomains):
- new = np.dot(D2[d].T, bary)
+ new = nx.dot(D2[d].T, bary)
K[d] = projR(K[d], new)
- err = np.linalg.norm(bary - old_bary)
- cpt = cpt + 1
+ err = nx.norm(bary - old_bary)
+
old_bary = bary
if log:
log['err'].append(err)
+ if err < stopThr:
+ break
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- bary = bary / np.sum(bary)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ else:
+ if warn:
+ warnings.warn("Algorithm did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ bary = bary / nx.sum(bary)
if log:
- log['niter'] = cpt
+ log['niter'] = ii
log['M'] = M
log['D1'] = D1
log['D2'] = D2
@@ -1657,8 +2717,8 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
- numIterMax=10000, stopThr=1e-9, verbose=False,
- log=False, **kwargs):
+ numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
+ log=False, warn=True, **kwargs):
r'''
Solve the entropic regularization optimal transport problem and return the
OT matrix from empirical data
@@ -1666,45 +2726,56 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`a` and :math:`b` are source and target weights (sum to 1)
+ - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (n_samples_a, dim)
+ X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (n_samples_b, dim)
+ X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (n_samples_a,)
+ a : array-like, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (n_samples_b,)
+ b : array-like, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
+ isLazy: boolean, optional
+ If True, compute the cost matrix block by block and return only
+ the dual potentials (to save memory). If False, compute the full
+ cost matrix and return the outputs of the sinkhorn function.
+ batchSize: int or tuple of 2 int, optional
+ Size of the batches used to compute the sinkhorn update without memory overhead.
+ When a tuple is provided it sets the size of the left/right batches.
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- gamma : ndarray, shape (n_samples_a, n_samples_b)
+ gamma : array-like, shape (n_samples_a, n_samples_b)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if log==True in parameters
@@ -1715,9 +2786,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
>>> n_samples_a = 2
>>> n_samples_b = 2
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
- >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
- >>> empirical_sinkhorn(X_s, X_t, reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE
+ >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
+ >>> empirical_sinkhorn(X_s, X_t, reg=reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE
array([[4.99977301e-01, 2.26989344e-05],
[2.26989344e-05, 4.99977301e-01]])
@@ -1725,30 +2796,115 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
'''
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ nx = get_backend(X_s, X_t)
+
+ ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
- a = unif(np.shape(X_s)[0])
+ a = nx.from_numpy(unif(ns), type_as=X_s)
if b is None:
- b = unif(np.shape(X_t)[0])
+ b = nx.from_numpy(unif(nt), type_as=X_s)
+
+ if isLazy:
+ if log:
+ dict_log = {"err": []}
- M = dist(X_s, X_t, metric=metric)
+ log_a, log_b = nx.log(a), nx.log(b)
+ f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
+
+ if isinstance(batchSize, int):
+ bs, bt = batchSize, batchSize
+ elif isinstance(batchSize, tuple) and len(batchSize) == 2:
+ bs, bt = batchSize[0], batchSize[1]
+ else:
+ raise ValueError("Batch size must be in integer or a tuple of two integers")
+
+ range_s, range_t = range(0, ns, bs), range(0, nt, bt)
+
+ lse_f = nx.zeros((ns,), type_as=a)
+ lse_g = nx.zeros((nt,), type_as=a)
+
+ X_s_np = nx.to_numpy(X_s)
+ X_t_np = nx.to_numpy(X_t)
+
+ for i_ot in range(numIterMax):
+
+ lse_f_cols = []
+ for i in range_s:
+ M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
+ M = nx.from_numpy(M, type_as=a)
+ lse_f_cols.append(
+ nx.logsumexp(g[None, :] - M / reg, axis=1)
+ )
+ lse_f = nx.concatenate(lse_f_cols, axis=0)
+ f = log_a - lse_f
+
+ lse_g_cols = []
+ for j in range_t:
+ M = dist(X_s_np, X_t_np[j:j + bt, :], metric=metric)
+ M = nx.from_numpy(M, type_as=a)
+ lse_g_cols.append(
+ nx.logsumexp(f[:, None] - M / reg, axis=0)
+ )
+ lse_g = nx.concatenate(lse_g_cols, axis=0)
+ g = log_b - lse_g
+
+ if (i_ot + 1) % 10 == 0:
+ m1_cols = []
+ for i in range_s:
+ M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
+ M = nx.from_numpy(M, type_as=a)
+ m1_cols.append(
+ nx.sum(nx.exp(f[i:i + bs, None] + g[None, :] - M / reg), axis=1)
+ )
+ m1 = nx.concatenate(m1_cols, axis=0)
+ err = nx.sum(nx.abs(m1 - a))
+ if log:
+ dict_log["err"].append(err)
+
+ if verbose and (i_ot + 1) % 100 == 0:
+ print("Error in marginal at iteration {} = {}".format(i_ot + 1, err))
+
+ if err <= stopThr:
+ break
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ if log:
+ dict_log["u"] = f
+ dict_log["v"] = g
+ return (f, g, dict_log)
+ else:
+ return (f, g)
- if log:
- pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
- return pi, log
else:
- pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
- return pi
+ M = dist(X_s, X_t, metric=metric)
+ if log:
+ pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=True, **kwargs)
+ return pi, log
+ else:
+ pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=False, **kwargs)
+ return pi
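
A sketch of the new lazy mode (illustrative sizes): with `isLazy=True` the
full n_s x n_t cost matrix is never materialized, and only the dual
potentials come back:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    X_s, X_t = rng.randn(5000, 2), rng.randn(5000, 2)
    f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, reg=1., isLazy=True,
                                         batchSize=500)
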
-def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
- verbose=False, log=False, **kwargs):
+def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
+ numIterMax=10000, stopThr=1e-9, isLazy=False,
+ batchSize=100, verbose=False, log=False, warn=True, **kwargs):
r'''
Solve the entropic regularization optimal transport problem from empirical
data and return the OT loss
@@ -1757,46 +2913,57 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`a` and :math:`b` are source and target weights (sum to 1)
+ - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (n_samples_a, dim)
+ X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (n_samples_b, dim)
+ X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (n_samples_a,)
+ a : array-like, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (n_samples_b,)
+ b : array-like, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
+ isLazy: boolean, optional
+ If True, compute the cost matrix block by block and return only
+ the dual potentials (to save memory). If False, compute the
+ full cost matrix and return the outputs of the sinkhorn function.
+ batchSize: int or tuple of 2 int, optional
+ Size of the batches used to compute the sinkhorn update without memory overhead.
+ When a tuple is provided it sets the size of the left/right batches.
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- gamma : ndarray, shape (n_samples_a, n_samples_b)
- Regularized optimal transportation matrix for the given parameters
+ W : (n_hists,) array-like or float
+ Optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1806,41 +2973,94 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
>>> n_samples_a = 2
>>> n_samples_b = 2
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
- >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
- >>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
- array([4.53978687e-05])
+ >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
+ >>> b = np.full((n_samples_b, 3), 1/n_samples_b)
+ >>> empirical_sinkhorn2(X_s, X_t, b=b, reg=reg, verbose=False)
+ array([4.53978687e-05, 4.53978687e-05, 4.53978687e-05])
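+
+ A lazy-evaluation sketch reusing the arrays above (computes the loss
+ block by block from the dual potentials; the exact value is
+ data-dependent, hence the skipped doctest):
+
+ >>> empirical_sinkhorn2(X_s, X_t, reg, isLazy=True, batchSize=1)  # doctest: +SKIP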
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information
+ Processing Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling
+ Algorithms for Entropy Regularized Transport Problems.
+ arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
'''
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ nx = get_backend(X_s, X_t)
+
+ ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
- a = unif(np.shape(X_s)[0])
+ a = nx.from_numpy(unif(ns), type_as=X_s)
if b is None:
- b = unif(np.shape(X_t)[0])
+ b = nx.from_numpy(unif(nt), type_as=X_s)
- M = dist(X_s, X_t, metric=metric)
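+ # in lazy mode, delegate to empirical_sinkhorn to get the dual
+ # potentials f, g, then evaluate the loss block by block below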
+ if isLazy:
+ if log:
+ f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
+ numIterMax=numIterMax,
+ stopThr=stopThr,
+ isLazy=isLazy,
+ batchSize=batchSize,
+ verbose=verbose, log=log,
+ warn=warn)
+ else:
+ f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
+ numIterMax=numIterMax, stopThr=stopThr,
+ isLazy=isLazy, batchSize=batchSize,
+ verbose=verbose, log=log,
+ warn=warn)
+
+ bs = batchSize if isinstance(batchSize, int) else batchSize[0]
+ range_s = range(0, ns, bs)
+
+ loss = 0
+
+ X_s_np = nx.to_numpy(X_s)
+ X_t_np = nx.to_numpy(X_t)
+
+ for i in range_s:
+ M_block = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
+ M_block = nx.from_numpy(M_block, type_as=a)
+ pi_block = nx.exp(f[i:i + bs, None] + g[None, :] - M_block / reg)
+ loss += nx.sum(M_block * pi_block)
+
+ if log:
+ return loss, dict_log
+ else:
+ return loss
- if log:
- sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- return sinkhorn_loss, log
else:
- sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- return sinkhorn_loss
+ M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric)
+ M = nx.from_numpy(M, type_as=a)
+
+ if log:
+ sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+ return sinkhorn_loss, log
+ else:
+ sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+ return sinkhorn_loss
-def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
- verbose=False, log=False, **kwargs):
+def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
+ numIterMax=10000, stopThr=1e-9,
+ verbose=False, log=False, warn=True,
+ **kwargs):
r'''
Compute the sinkhorn divergence loss from empirical data
@@ -1849,64 +3069,72 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
.. math::
- W &= \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ W &= \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- W_a &= \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a)
+ W_a &= \min_{\gamma_a} \quad \langle \gamma_a, \mathbf{M_a} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma_a)
- W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b)
+ W_b &= \min_{\gamma_b} \quad \langle \gamma_b, \mathbf{M_b} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma_b)
- S &= W - 1/2 * (W_a + W_b)
+ S &= W - \frac{W_a + W_b}{2}
.. math::
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
- \gamma_a 1 = a
+ \gamma_a \mathbf{1} &= \mathbf{a}
- \gamma_a^T 1= a
+ \gamma_a^T \mathbf{1} &= \mathbf{a}
- \gamma_a\geq 0
+ \gamma_a &\geq 0
- \gamma_b 1 = b
+ \gamma_b \mathbf{1} &= \mathbf{b}
- \gamma_b^T 1= b
+ \gamma_b^T \mathbf{1} &= \mathbf{b}
- \gamma_b\geq 0
+ \gamma_b &\geq 0
where :
- - :math:`M` (resp. :math:`M_a, M_b`) is the (n_samples_a, n_samples_b) metric cost matrix (resp (n_samples_a, n_samples_a) and (n_samples_b, n_samples_b))
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`a` and :math:`b` are source and target weights (sum to 1)
+ - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`)
+ is the (`n_samples_a`, `n_samples_b`) metric cost matrix
+ (resp. (`n_samples_a`, `n_samples_a`) and (`n_samples_b`, `n_samples_b`))
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (n_samples_a, dim)
+ X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (n_samples_b, dim)
+ X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (n_samples_a,)
+ a : array-like, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (n_samples_b,)
+ b : array-like, shape (n_samples_b,)
samples weights in the target domain
numIterMax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- gamma : ndarray, shape (n_samples_a, n_samples_b)
- Regularized optimal transportation matrix for the given parameters
+ W : (1,) array-like
+ Symmetrized optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1915,27 +3143,36 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
>>> n_samples_a = 2
>>> n_samples_b = 4
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
- >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
+ >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
>>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS
- array([1.499...])
+ 1.499887176049052
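+
+ A sketch with explicit uniform weights (hypothetical histograms,
+ equivalent to the defaults; doctest skipped since the value is
+ data-dependent):
+
+ >>> a = np.full(n_samples_a, 1. / n_samples_a)
+ >>> b = np.full(n_samples_b, 1. / n_samples_b)
+ >>> empirical_sinkhorn_divergence(X_s, X_t, reg, a=a, b=b)  # doctest: +SKIP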
References
----------
- .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
+ .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative
+ Models with Sinkhorn Divergences, Proceedings of the Twenty-First
+ International Conference on Artificial Intelligence and Statistics,
+ (AISTATS) 21, 2018
'''
if log:
- sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
+ numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose,
+ log=log, warn=warn, **kwargs)
- sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
+ numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose,
+ log=log, warn=warn, **kwargs)
- sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
+ numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose,
+ log=log, warn=warn, **kwargs)
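+
+ # Sinkhorn divergence: debias the cross loss with the two self losses,
+ # S = W(a, b) - (W(a, a) + W(b, b)) / 2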
- sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
log = {}
log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
@@ -1948,99 +3185,119 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
return max(0, sinkhorn_div), log
else:
- sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
+ numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log,
+ warn=warn, **kwargs)
+
+ sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
+ numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log,
+ warn=warn, **kwargs)
+
+ sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
+ numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log,
+ warn=warn, **kwargs)
+
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ return max(0, sinkhorn_div)
- sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log, **kwargs)
+def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
+ restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09,
+ verbose=False, log=False):
+ r"""
+ Screening Sinkhorn Algorithm for Regularized Optimal Transport
- sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
- return max(0, sinkhorn_div)
+ The function solves an approximate dual of the Sinkhorn divergence :ref:`[2]
+ <references-screenkhorn>`, which is written as the following optimization problem:
+
+ .. math::
+ (\mathbf{u}, \mathbf{v}) = \mathop{\arg \min}_{\mathbf{u}, \mathbf{v}} \quad
+ \mathbf{1}_{ns}^T \mathbf{B}(\mathbf{u}, \mathbf{v}) \mathbf{1}_{nt} -
+ \langle \kappa \mathbf{u}, \mathbf{a} \rangle -
+ \langle \frac{1}{\kappa} \mathbf{v}, \mathbf{b} \rangle
-def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True,
- maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False):
- r""""
- Screening Sinkhorn Algorithm for Regularized Optimal Transport
+ where:
- The function solves an approximate dual of Sinkhorn divergence [2] which is written as the following optimization problem:
+ .. math::
- ..math::
- (u, v) = \argmin_{u, v} 1_{ns}^T B(u,v) 1_{nt} - <\kappa u, a> - <v/\kappa, b>
+ \mathbf{B}(\mathbf{u}, \mathbf{v}) = \mathrm{diag}(e^\mathbf{u}) \mathbf{K} \mathrm{diag}(e^\mathbf{v}) \text{, with } \mathbf{K} = e^{-\mathbf{M} / \mathrm{reg}} \text{ and}
- where B(u,v) = \diag(e^u) K \diag(e^v), with K = e^{-M/reg} and
+ .. math::
- s.t. e^{u_i} \geq \epsilon / \kappa, for all i \in {1, ..., ns}
+ s.t. \ e^{u_i} &\geq \epsilon / \kappa, \forall i \in \{1, \ldots, ns\}
- e^{v_j} \geq \epsilon \kappa, for all j \in {1, ..., nt}
+ e^{v_j} &\geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\}
- The parameters \kappa and \epsilon are determined w.r.t the couple number budget of points (ns_budget, nt_budget), see Equation (5) in [26]
+ The parameters `kappa` and `epsilon` are determined w.r.t. the pair of
+ point budgets (`ns_budget`, `nt_budget`); see Equation (5)
+ in :ref:`[26] <references-screenkhorn>`
Parameters
----------
- a : `numpy.ndarray`, shape=(ns,)
+ a: array-like, shape=(ns,)
samples weights in the source domain
-
- b : `numpy.ndarray`, shape=(nt,)
+ b: array-like, shape=(nt,)
samples weights in the target domain
-
- M : `numpy.ndarray`, shape=(ns, nt)
+ M: array-like, shape=(ns, nt)
Cost matrix
-
- reg : `float`
+ reg: `float`
Level of the entropy regularisation
-
- ns_budget : `int`, deafult=None
- Number budget of points to be keeped in the source domain
- If it is None then 50% of the source sample points will be keeped
-
- nt_budget : `int`, deafult=None
- Number budget of points to be keeped in the target domain
- If it is None then 50% of the target sample points will be keeped
-
- uniform : `bool`, default=False
- If `True`, the source and target distribution are supposed to be uniform, i.e., a_i = 1 / ns and b_j = 1 / nt
-
+ ns_budget: `int`, default=None
+ Number budget of points to be kept in the source domain.
+ If it is None then 50% of the source sample points will be kept
+ nt_budget: `int`, default=None
+ Number budget of points to be kept in the target domain.
+ If it is None then 50% of the target sample points will be kept
+ uniform: `bool`, default=False
+ If `True`, the source and target distribution are supposed to be uniform,
+ i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt`
restricted : `bool`, default=True
If `True`, a warm-start initialization for the L-BFGS-B solver is
computed using a restricted Sinkhorn algorithm with at most 5 iterations
-
- maxiter : `int`, default=10000
+ maxiter: `int`, default=10000
Maximum number of iterations in LBFGS solver
+ maxfun: `int`, default=10000
+ Maximum number of function evaluations in LBFGS solver
+ pgtol: `float`, default=1e-09
+ Final objective function accuracy in LBFGS solver
+ verbose: `bool`, default=False
+ If `True`, display information about the cardinalities of the active
+ sets and the parameters kappa and epsilon
- maxfun : `int`, default=10000
- Maximum number of function evaluations in LBFGS solver
- pgtol : `float`, default=1e-09
- Final objective function accuracy in LBFGS solver
+ .. admonition:: Dependency
- verbose : `bool`, default=False
- If `True`, dispaly informations about the cardinals of the active sets and the paramerters kappa
- and epsilon
+ To gain more efficiency, :py:func:`ot.bregman.screenkhorn` needs to call the "Bottleneck"
+ package (https://pypi.org/project/Bottleneck/) in the screening pre-processing step.
- Dependency
- ----------
- To gain more efficiency, screenkhorn needs to call the "Bottleneck" package (https://pypi.org/project/Bottleneck/)
- in the screening pre-processing step. If Bottleneck isn't installed, the following error message appears:
- "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/"
+ If Bottleneck isn't installed, the following error message appears:
+
+ "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/"
Returns
-------
- gamma : `numpy.ndarray`, shape=(ns, nt)
+ gamma : array-like, shape=(ns, nt)
Screened optimal transportation matrix for the given parameters
log : `dict`, default=False
Log dictionary return only if log==True in parameters
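+
+ Examples
+ --------
+ A minimal sketch on hypothetical toy data (the screened plan depends
+ on the chosen budgets, hence the skipped doctest):
+
+ >>> import numpy as np
+ >>> from ot.bregman import screenkhorn
+ >>> a = b = np.full(4, 0.25)
+ >>> M = np.abs(np.arange(4.)[:, None] - np.arange(4.)[None, :]) ** 2
+ >>> G = screenkhorn(a, b, M, reg=1., ns_budget=2, nt_budget=2)  # doctest: +SKIP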
+ .. _references-screenkhorn:
References
-----------
- .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport,
+ Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019).
+ Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019
"""
# check if bottleneck module exists
@@ -2048,12 +3305,17 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
import bottleneck
except ImportError:
warnings.warn(
- "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
+ "Bottleneck module is not installed. Install it from"
+ " https://pypi.org/project/Bottleneck/ for better performance.")
bottleneck = np
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
+ if nx.__name__ == "jax":
+ raise TypeError("JAX arrays have been received but screenkhorn is not "
+ "compatible with JAX.")
+
ns, nt = M.shape
# by default, we keep only 50% of the sample data points
@@ -2063,9 +3325,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
nt_budget = int(np.floor(0.5 * nt))
# calculate the Gibbs kernel
- K = np.empty_like(M)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
def projection(u, epsilon):
u[u <= epsilon] = epsilon
@@ -2077,8 +3337,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
if ns_budget == ns and nt_budget == nt:
# full number of budget points (ns, nt) = (ns_budget, nt_budget)
- Isel = np.ones(ns, dtype=bool)
- Jsel = np.ones(nt, dtype=bool)
+ Isel = nx.from_numpy(np.ones(ns, dtype=bool))
+ Jsel = nx.from_numpy(np.ones(nt, dtype=bool))
epsilon = 0.0
kappa = 1.0
@@ -2094,57 +3354,63 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
K_IJc = []
K_IcJ = []
- vec_eps_IJc = np.zeros(nt)
- vec_eps_IcJ = np.zeros(ns)
+ vec_eps_IJc = nx.zeros((nt,), type_as=M)
+ vec_eps_IcJ = nx.zeros((ns,), type_as=M)
else:
# sum of rows and columns of K
- K_sum_cols = K.sum(axis=1)
- K_sum_rows = K.sum(axis=0)
+ K_sum_cols = nx.sum(K, axis=1)
+ K_sum_rows = nx.sum(K, axis=0)
if uniform:
if ns / ns_budget < 4:
- aK_sort = np.sort(K_sum_cols)
+ aK_sort = nx.sort(K_sum_cols)
epsilon_u_square = a[0] / aK_sort[ns_budget - 1]
else:
- aK_sort = bottleneck.partition(K_sum_cols, ns_budget - 1)[ns_budget - 1]
+ aK_sort = nx.from_numpy(
+ bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1],
+ type_as=M
+ )
epsilon_u_square = a[0] / aK_sort
if nt / nt_budget < 4:
- bK_sort = np.sort(K_sum_rows)
+ bK_sort = nx.sort(K_sum_rows)
epsilon_v_square = b[0] / bK_sort[nt_budget - 1]
else:
- bK_sort = bottleneck.partition(K_sum_rows, nt_budget - 1)[nt_budget - 1]
+ bK_sort = nx.from_numpy(
+ bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1],
+ type_as=M
+ )
epsilon_v_square = b[0] / bK_sort
else:
aK = a / K_sum_cols
bK = b / K_sum_rows
- aK_sort = np.sort(aK)[::-1]
+ aK_sort = nx.flip(nx.sort(aK), axis=0)
epsilon_u_square = aK_sort[ns_budget - 1]
- bK_sort = np.sort(bK)[::-1]
+ bK_sort = nx.flip(nx.sort(bK), axis=0)
epsilon_v_square = bK_sort[nt_budget - 1]
# active sets I and J (see Lemma 1 in [26])
Isel = a >= epsilon_u_square * K_sum_cols
Jsel = b >= epsilon_v_square * K_sum_rows
- if sum(Isel) != ns_budget:
+ if nx.sum(Isel) != ns_budget:
if uniform:
aK = a / K_sum_cols
- aK_sort = np.sort(aK)[::-1]
- epsilon_u_square = aK_sort[ns_budget - 1:ns_budget + 1].mean()
+ aK_sort = nx.flip(nx.sort(aK), axis=0)
+ epsilon_u_square = nx.mean(aK_sort[ns_budget - 1:ns_budget + 1])
Isel = a >= epsilon_u_square * K_sum_cols
- ns_budget = sum(Isel)
+ ns_budget = nx.sum(Isel)
- if sum(Jsel) != nt_budget:
+ if nx.sum(Jsel) != nt_budget:
if uniform:
bK = b / K_sum_rows
- bK_sort = np.sort(bK)[::-1]
- epsilon_v_square = bK_sort[nt_budget - 1:nt_budget + 1].mean()
+ bK_sort = nx.flip(nx.sort(bK), axis=0)
+ epsilon_v_square = nx.mean(bK_sort[nt_budget - 1:nt_budget + 1])
Jsel = b >= epsilon_v_square * K_sum_rows
- nt_budget = sum(Jsel)
+ nt_budget = nx.sum(Jsel)
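+
+ # screening parameters from the squared thresholds (Equation (5) in
+ # [26]): epsilon is the geometric mean of the two thresholds and
+ # kappa their ratio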
epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4)
kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2)
@@ -2152,7 +3418,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
if verbose:
print("epsilon = %s\n" % epsilon)
print("kappa = %s\n" % kappa)
- print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' % (sum(Isel), sum(Jsel)))
+ print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n'
+ % (nx.sum(Isel), nx.sum(Jsel)))
# Ic, Jc: complementary of the active sets I and J
Ic = ~Isel
@@ -2162,18 +3429,18 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
K_IcJ = K[np.ix_(Ic, Jsel)]
K_IJc = K[np.ix_(Isel, Jc)]
- K_min = K_IJ.min()
+ K_min = nx.min(K_IJ)
if K_min == 0:
- K_min = np.finfo(float).tiny
+ K_min = float(np.finfo(float).tiny)
# a_I, b_J, a_Ic, b_Jc
a_I = a[Isel]
b_J = b[Jsel]
if not uniform:
- a_I_min = a_I.min()
- a_I_max = a_I.max()
- b_J_max = b_J.max()
- b_J_min = b_J.min()
+ a_I_min = nx.min(a_I)
+ a_I_max = nx.max(a_I)
+ b_J_max = nx.max(b_J)
+ b_J_min = nx.min(b_J)
else:
a_I_min = a_I[0]
a_I_max = a_I[0]
@@ -2182,33 +3449,37 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
# box constraints in L-BFGS-B (see Proposition 1 in [26])
bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / (
- ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
+ ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
bounds_v = [(
- max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
- epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
+ max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
+ epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
# pre-calculated constants for the objective
- vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1)
- vec_eps_IcJ = (epsilon / kappa) * (np.ones(ns - ns_budget).reshape((-1, 1)) * K_IcJ).sum(axis=0)
+ vec_eps_IJc = epsilon * kappa * nx.sum(
+ K_IJc * nx.ones((nt - nt_budget,), type_as=M)[None, :],
+ axis=1
+ )
+ vec_eps_IcJ = (epsilon / kappa) * nx.sum(
+ nx.ones((ns - ns_budget,), type_as=M)[:, None] * K_IcJ,
+ axis=0
+ )
# initialisation
- u0 = np.full(ns_budget, (1. / ns_budget) + epsilon / kappa)
- v0 = np.full(nt_budget, (1. / nt_budget) + epsilon * kappa)
+ u0 = nx.full((ns_budget,), 1. / ns_budget + epsilon / kappa, type_as=M)
+ v0 = nx.full((nt_budget,), 1. / nt_budget + epsilon * kappa, type_as=M)
# pre-calculed constants for Restricted Sinkhorn (see Algorithm 1 in supplementary of [26])
if restricted:
if ns_budget != ns or nt_budget != nt:
- cst_u = kappa * epsilon * K_IJc.sum(axis=1)
- cst_v = epsilon * K_IcJ.sum(axis=0) / kappa
+ cst_u = kappa * epsilon * nx.sum(K_IJc, axis=1)
+ cst_v = epsilon * nx.sum(K_IcJ, axis=0) / kappa
- cpt = 1
- while cpt < 5: # 5 iterations
- K_IJ_v = np.dot(K_IJ.T, u0) + cst_v
+ for _ in range(5): # 5 iterations
+ K_IJ_v = nx.dot(K_IJ.T, u0) + cst_v
v0 = b_J / (kappa * K_IJ_v)
- KIJ_u = np.dot(K_IJ, v0) + cst_u
+ KIJ_u = nx.dot(K_IJ, v0) + cst_u
u0 = (kappa * a_I) / KIJ_u
- cpt += 1
u0 = projection(u0, epsilon / kappa)
v0 = projection(v0, epsilon * kappa)
@@ -2219,15 +3490,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
def restricted_sinkhorn(usc, vsc, max_iter=5):
"""
- Restricted Sinkhorn Algorithm as a warm-start initialized point for L-BFGS-B (see Algorithm 1 in supplementary of [26])
+ Restricted Sinkhorn algorithm as a warm-start initialized point for
+ L-BFGS-B (see Algorithm 1 in supplementary of [26])
"""
- cpt = 1
- while cpt < max_iter:
- K_IJ_v = np.dot(K_IJ.T, usc) + cst_v
+ for _ in range(max_iter):
+ K_IJ_v = nx.dot(K_IJ.T, usc) + cst_v
vsc = b_J / (kappa * K_IJ_v)
- KIJ_u = np.dot(K_IJ, vsc) + cst_u
+ KIJ_u = nx.dot(K_IJ, vsc) + cst_u
usc = (kappa * a_I) / KIJ_u
- cpt += 1
usc = projection(usc, epsilon / kappa)
vsc = projection(vsc, epsilon * kappa)
@@ -2235,17 +3504,20 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
return usc, vsc
def screened_obj(usc, vsc):
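+ # screened dual objective: active-set bilinear and entropy terms plus
+ # two precomputed constant terms on the complementary sets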
- part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J,
- np.log(vsc))
- part_IJc = np.dot(usc, vec_eps_IJc)
- part_IcJ = np.dot(vec_eps_IcJ, vsc)
+ part_IJ = (
+ nx.dot(nx.dot(usc, K_IJ), vsc)
+ - kappa * nx.dot(a_I, nx.log(usc))
+ - (1. / kappa) * nx.dot(b_J, nx.log(vsc))
+ )
+ part_IJc = nx.dot(usc, vec_eps_IJc)
+ part_IcJ = nx.dot(vec_eps_IcJ, vsc)
psi_epsilon = part_IJ + part_IJc + part_IcJ
return psi_epsilon
def screened_grad(usc, vsc):
# gradients of Psi_(kappa,epsilon) w.r.t u and v
- grad_u = np.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc
- grad_v = np.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc
+ grad_u = nx.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc
+ grad_v = nx.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc
return grad_u, grad_v
def bfgspost(theta):
@@ -2255,20 +3527,20 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
f = screened_obj(u, v)
# gradient
g_u, g_v = screened_grad(u, v)
- g = np.hstack([g_u, g_v])
- return f, g
+ g = nx.concatenate([g_u, g_v], axis=0)
+ return nx.to_numpy(f), nx.to_numpy(g)
# ----------------------------------------------------------------------------------------------------------------#
# Step 2: L-BFGS-B solver #
# ----------------------------------------------------------------------------------------------------------------#
u0, v0 = restricted_sinkhorn(u0, v0)
- theta0 = np.hstack([u0, v0])
+ theta0 = nx.concatenate([u0, v0], axis=0)
bounds = bounds_u + bounds_v # constraint bounds
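+
+ # SciPy's L-BFGS-B operates on NumPy arrays, so the wrapper moves the
+ # flat parameter vector to the working backend before calling bfgspost,
+ # which returns NumPy outputs for the solver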
def obj(theta):
- return bfgspost(theta)
+ return bfgspost(nx.from_numpy(theta, type_as=M))
theta, _, _ = fmin_l_bfgs_b(func=obj,
x0=theta0,
@@ -2276,12 +3548,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
maxfun=maxfun,
pgtol=pgtol,
maxiter=maxiter)
+ theta = nx.from_numpy(theta, type_as=M)
usc = theta[:ns_budget]
vsc = theta[ns_budget:]
- usc_full = np.full(ns, epsilon / kappa)
- vsc_full = np.full(nt, epsilon * kappa)
+ usc_full = nx.full((ns,), epsilon / kappa, type_as=M)
+ vsc_full = nx.full((nt,), epsilon * kappa, type_as=M)
usc_full[Isel] = usc
vsc_full[Jsel] = vsc
@@ -2293,7 +3566,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
log['Jsel'] = Jsel
gamma = usc_full[:, None] * K * vsc_full[None, :]
- gamma = gamma / gamma.sum()
+ gamma = gamma / nx.sum(gamma)
if log:
return gamma, log