authorRémi Flamary <remi.flamary@gmail.com>2021-10-27 08:41:08 +0200
committerGitHub <noreply@github.com>2021-10-27 08:41:08 +0200
commitd7554331fc409fea48ee758fd630909dd9dc4827 (patch)
tree9b8ed4bf94c12d034d5fb1de5b7b5b76c23b4d05
parent76450dddf8dd62b9714b72e99ae075516246d433 (diff)
[WIP] Sinkhorn in log space (#290)
* add sinkhorn log and working sinkhorn2 function * more tests pass * more tests pass * it works but not by default yet * remove warnings * update circleci doc * update circleci doc * new sinkhorn implemented but not by default * better * doctest pass * test doctest * new test utils * remove pep8 errors * remove pep8 errors * doc new implementation with log * test sinkhorn 2 * doc for log implementation
-rw-r--r--.circleci/config.yml14
-rw-r--r--README.md4
-rw-r--r--docs/source/quickstart.rst10
-rw-r--r--ot/bregman.py272
-rw-r--r--ot/dr.py4
-rw-r--r--ot/gromov.py4
-rw-r--r--ot/optim.py4
-rw-r--r--ot/utils.py4
-rw-r--r--test/test_bregman.py120
-rw-r--r--test/test_gromov.py10
-rw-r--r--test/test_helpers.py4
-rw-r--r--test/test_utils.py15
12 files changed, 403 insertions, 62 deletions
diff --git a/.circleci/config.yml b/.circleci/config.yml
index e4c71dd..379394a 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -4,7 +4,7 @@ version: 2
jobs:
build_docs:
docker:
- - image: circleci/python:3.7-stretch
+ - image: cimg/python:3.9
steps:
- checkout
- run:
@@ -35,18 +35,6 @@ jobs:
- pip-cache
- run:
- name: Spin up Xvfb
- command: |
- /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset;
-
- # https://github.com/ContinuumIO/anaconda-issues/issues/9190#issuecomment-386508136
- # https://github.com/golemfactory/golem/issues/1019
- - run:
- name: Fix libgcc_s.so.1 pthread_cancel bug
- command: |
- sudo apt-get install qt5-default
-
- - run:
name: Get Python running
command: |
python -m pip install --user --upgrade --progress-bar off pip
diff --git a/README.md b/README.md
index 266d847..ffad0bd 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ POT provides the following generic OT solvers (links to examples):
* [OT Network Simplex solver](https://pythonot.github.io/auto_examples/plot_OT_1D.html) for the linear program/ Earth Movers Distance [1] .
* [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7].
-* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html).
+* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html).
* Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4].
* Sinkhorn divergence [23] and entropic regularization OT from empirical data.
* [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
@@ -290,3 +290,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML).
[33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021
+
+[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index fd046a1..232df7b 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -358,6 +358,11 @@ More details about the algorithms used are given in the following note.
+ :code:`method='sinkhorn'` calls :any:`ot.bregman.sinkhorn_knopp` the
classic algorithm [2]_.
+ + :code:`method='sinkhorn_log'` calls :any:`ot.bregman.sinkhorn_log` the
+ Sinkhorn algorithm in log space [2]_ that is more stable but can be
+ slower in numpy since `logsumexp` is not implemented in parallel.
+ It is the recommended solver for applications that require
+ differentiability with a small number of iterations.
+ :code:`method='sinkhorn_stabilized'` calls :any:`ot.bregman.sinkhorn_stabilized` the
log stabilized version of the algorithm [9]_.
+ :code:`method='sinkhorn_epsilon_scaling'` calls
@@ -389,7 +394,10 @@ More details about the algorithms used are given in the following note.
solutions. Note that the greedy version of the Sinkhorn
:any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
version of the Sinkhorn :any:`ot.bregman.screenkhorn` aims at providing a
- fast approximation of the Sinkhorn problem.
+ fast approximation of the Sinkhorn problem. For GPU use and gradient
+ computation with a small number of iterations we strongly recommend the
+ :any:`ot.bregman.sinkhorn_log` solver, which does not need to check for
+ numerical problems.
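As a minimal usage sketch of the solver selection documented above (toy inputs for illustration, assuming a POT build that includes this PR):

    import numpy as np
    import ot

    a = np.array([.5, .5])              # source histogram
    b = np.array([.5, .5])              # target histogram
    M = np.array([[0., 1.], [1., 0.]])  # cost matrix

    # select the log-domain solver by name; on these symmetric toy
    # inputs it returns the same plan as the default 'sinkhorn' method
    G = ot.sinkhorn(a, b, M, 1, method='sinkhorn_log')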
diff --git a/ot/bregman.py b/ot/bregman.py
index b59ee1b..2aa76ff 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -64,7 +64,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
solutions. Note that the greedy version of the sinkhorn
:py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aims at providing a
- fast approximation of the Sinkhorn problem.
+ fast approximation of the Sinkhorn problem. For GPU use and gradient
+ computation with a small number of iterations we strongly recommend the
+ :any:`ot.bregman.sinkhorn_log` solver, which does not need to check for
+ numerical problems.
Parameters
@@ -79,8 +82,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
reg : float
Regularization term >0
method : str
- method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ method used for the solver, either 'sinkhorn', 'sinkhorn_log',
+ 'greenkhorn', 'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling'; see
+ those functions for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
@@ -118,6 +122,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
See Also
@@ -134,6 +139,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
**kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
elif method.lower() == 'greenkhorn':
return greenkhorn(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log)
@@ -182,7 +191,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
By default and when using a regularization parameter that is not too small
the default sinkhorn solver should be enough. If you need to use a small
regularization to get sharper OT matrices, you should use the
- :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
+ :any:`ot.bregman.sinkhorn_log` solver that will avoid numerical
errors. This last solver can be very slow in practice and might not even
converge to a reasonable OT matrix in a finite time. This is why
:any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
@@ -190,7 +199,10 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
solutions. Note that the greedy version of the sinkhorn
:any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
version of the sinkhorn :any:`ot.bregman.screenkhorn` aims at providing a
- fast approximation of the Sinkhorn problem.
+ fast approximation of the Sinkhorn problem. For GPU use and gradient
+ computation with a small number of iterations we strongly recommend the
+ :any:`ot.bregman.sinkhorn_log` solver, which does not need to check for
+ numerical problems.
Parameters
----------
@@ -204,7 +216,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
reg : float
Regularization term >0
method : str
- method used for the solver either 'sinkhorn', 'sinkhorn_stabilized', see those function for specific parameters
+ method used for the solver, either 'sinkhorn', 'sinkhorn_log' or
+ 'sinkhorn_stabilized'; see those functions for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
@@ -230,7 +243,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
>>> ot.sinkhorn2(a, b, M, 1)
- array([0.26894142])
+ 0.26894142136999516
.. _references-sinkhorn2:
@@ -243,7 +256,11 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
- .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+ .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation
+ algorithms for optimal transport via Sinkhorn iteration, Advances in Neural
+ Information Processing Systems (NIPS) 31, 2017
+
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
@@ -257,20 +274,45 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
"""
- b = list_to_array(b)
+ M, a, b = list_to_array(M, a, b)
+ nx = get_backend(M, a, b)
+
if len(b.shape) < 2:
- b = b[:, None]
+ if method.lower() == 'sinkhorn':
+ res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ res = sinkhorn_log(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+ if log:
+ return nx.sum(M * res[0]), res[1]
+ else:
+ return nx.sum(M * res)
- if method.lower() == 'sinkhorn':
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- elif method.lower() == 'sinkhorn_stabilized':
- return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
else:
- raise ValueError("Unknown method '%s'." % method)
+
+ if method.lower() == 'sinkhorn':
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
@@ -361,7 +403,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
# init data
dim_a = len(a)
- dim_b = len(b)
+ dim_b = b.shape[0]
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -438,6 +480,191 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
+def sinkhorn_log(a, b, M, reg, numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
+ r"""
+ Solve the entropic regularization optimal transport problem in log space
+ and return the OT matrix
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm :ref:`[2] <references-sinkhorn-log>` with the
+ implementation from :ref:`[34] <references-sinkhorn-log>`
+
+
+ Parameters
+ ----------
+ a : array-like, shape (dim_a,)
+ samples weights in the source domain
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
+ samples in the target domain; compute Sinkhorn with multiple targets
+ and a fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (returns OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Regularization term >0
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ gamma : array-like, shape (dim_a, dim_b)
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.], [1., 0.]]
+ >>> ot.sinkhorn(a, b, M, 1)
+ array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]])
+
+
+ .. _references-sinkhorn-log:
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
+
+ if len(a) == 0:
+ a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
+ if len(b) == 0:
+ b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)
+
+ # init data
+ dim_a = len(a)
+ dim_b = b.shape[0]
+
+ if len(b.shape) > 1:
+ n_hists = b.shape[1]
+ else:
+ n_hists = 0
+
+ if n_hists: # we do not want to use tensors so we do a loop
+
+ lst_loss = []
+ lst_u = []
+ lst_v = []
+
+ for k in range(n_hists):
+ res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+ if log:
+ lst_loss.append(nx.sum(M * res[0]))
+ lst_u.append(res[1]['log_u'])
+ lst_v.append(res[1]['log_v'])
+ else:
+ lst_loss.append(nx.sum(M * res))
+ res = nx.stack(lst_loss)
+ if log:
+ log = {'log_u': nx.stack(lst_u, 1),
+ 'log_v': nx.stack(lst_v, 1), }
+ log['u'] = nx.exp(log['log_u'])
+ log['v'] = nx.exp(log['log_v'])
+ return res, log
+ else:
+ return res
+
+ else:
+
+ if log:
+ log = {'err': []}
+
+ Mr = M / (-reg)
+
+ # we assume that no distances are null except those on the diagonal
+
+ u = nx.zeros(dim_a, type_as=M)
+ v = nx.zeros(dim_b, type_as=M)
+
+ def get_logT(u, v):
+ if n_hists:
+ return Mr[:, :, None] + u + v
+ else:
+ return Mr + u[:, None] + v[None, :]
+
+ loga = nx.log(a)
+ logb = nx.log(b)
+
+ cpt = 0
+ err = 1
+ while (err > stopThr and cpt < numItermax):
+
+ v = logb - nx.logsumexp(Mr + u[:, None], 0)
+ u = loga - nx.logsumexp(Mr + v[None, :], 1)
+
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only
+ # every 10th iteration
+
+ # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
+ tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0)
+ err = nx.norm(tmp2 - b) # violation of marginal
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+ cpt = cpt + 1
+
+ if log:
+ log['log_u'] = u
+ log['log_v'] = v
+ log['u'] = nx.exp(u)
+ log['v'] = nx.exp(v)
+
+ return nx.exp(get_logT(u, v)), log
+
+ else:
+ return nx.exp(get_logT(u, v))
+
+
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
log=False):
r"""
@@ -1881,8 +2108,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
return (f, g)
else:
- M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric)
- M = nx.from_numpy(M, type_as=a)
+ M = dist(X_s, X_t, metric=metric)
if log:
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
return pi, log
@@ -2102,7 +2328,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
>>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
>>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
>>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS
- array([1.499...])
+ 1.499887176049052
References
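For reference, the core iteration of the added `ot.bregman.sinkhorn_log` can be paraphrased in plain numpy. This is a simplified sketch (fixed iteration count; the real function adds a stopping criterion, the multi-histogram loop, and backend dispatch), not the actual implementation:

    import numpy as np
    from scipy.special import logsumexp

    def sinkhorn_log_sketch(a, b, M, reg, n_iter=1000):
        # dual potentials kept in log space, initialised at zero
        u, v = np.zeros(len(a)), np.zeros(len(b))
        Mr = -M / reg
        for _ in range(n_iter):
            v = np.log(b) - logsumexp(Mr + u[:, None], axis=0)
            u = np.log(a) - logsumexp(Mr + v[None, :], axis=1)
        # the plan is exponentiated only once, at the end, which is
        # what keeps the updates numerically stable for small reg
        return np.exp(Mr + u[:, None] + v[None, :])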
diff --git a/ot/dr.py b/ot/dr.py
index 64588cf..de39662 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -209,11 +209,11 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh
.. math::
\max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg * H(\pi)
-
+
- :math:`U` is a linear projection operator in the Stiefel(d, k) manifold
- :math:`H(\pi)` is entropy regularizer
- :math:`x_i`, :math:`y_j` are samples of measures \mu and \nu respectively
-
+
Parameters
----------
X : ndarray, shape (n, d)
diff --git a/ot/gromov.py b/ot/gromov.py
index 85b1549..33b4453 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -1030,7 +1030,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
# compute the gradient
tens = gwggrad(constC, hC1, hC2, T)
- T = sinkhorn(p, q, tens, epsilon)
+ T = sinkhorn(p, q, tens, epsilon, method='sinkhorn')
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
@@ -1204,7 +1204,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
Cprev = C
T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
- max_iter, 1e-5, verbose, log) for s in range(S)]
+ max_iter, 1e-4, verbose, log) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
diff --git a/ot/optim.py b/ot/optim.py
index 6822e4e..34cbb17 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -20,7 +20,7 @@ from .backend import get_backend
def line_search_armijo(f, xk, pk, gfk, old_fval,
args=(), c1=1e-4, alpha0=0.99):
- """
+ r"""
Armijo linesearch function that works with matrices
Find an approximate minimum of :math:`f(x_k + \\alpha \cdot p_k)` that satisfies the
@@ -447,7 +447,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
def solve_1d_linesearch_quad(a, b, c):
- """
+ r"""
For any convex or non-convex 1d quadratic function `f`, solve the following problem:
.. math::
diff --git a/ot/utils.py b/ot/utils.py
index 6a782e6..0608aee 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -183,7 +183,7 @@ def euclidean_distances(X, Y, squared=False):
return c
-def dist(x1, x2=None, metric='sqeuclidean'):
+def dist(x1, x2=None, metric='sqeuclidean', p=2):
"""Compute distance between samples in x1 and x2
.. note:: This function is backend-compatible and will work on arrays
@@ -222,7 +222,7 @@ def dist(x1, x2=None, metric='sqeuclidean'):
if not get_backend(x1, x2).__name__ == 'numpy':
raise NotImplementedError()
else:
- return cdist(x1, x2, metric=metric)
+ return cdist(x1, x2, metric=metric, p=p)
def dist0(n, method='lin_square'):
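A short sketch of the new keyword (mirroring the test added in test_utils.py below); `p` is simply forwarded to scipy's `cdist` and only affects the Minkowski metric:

    import numpy as np
    import ot

    x = np.random.RandomState(0).randn(5, 2)

    # p=0.5 was previously silently ignored; it now reaches cdist
    D = ot.dist(x, x, metric='minkowski', p=0.5)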
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 942cb6d..c1120ba 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -32,6 +32,27 @@ def test_sinkhorn():
u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+def test_sinkhorn_multi_b():
+ # test sinkhorn
+ n = 10
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ loss0, log = ot.sinkhorn(u, b, M, .1, stopThr=1e-10, log=True)
+
+ loss = [ot.sinkhorn2(u, b[:, k], M, .1, stopThr=1e-10) for k in range(3)]
+ # check constraints
+ np.testing.assert_allclose(
+ loss0, loss, atol=1e-06) # cf convergence sinkhorn
+
+
def test_sinkhorn_backends(nx):
n_samples = 100
n_features = 2
@@ -147,6 +168,7 @@ def test_sinkhorn_variants(nx):
Mb = nx.from_numpy(M)
G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_log', stopThr=1e-10))
G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10))
Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
Ges = nx.to_numpy(ot.sinkhorn(
@@ -155,15 +177,73 @@ def test_sinkhorn_variants(nx):
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
- print(G0, G_green)
+
+
+@pytest.skip_backend("jax")
+def test_sinkhorn_variants_multi_b(nx):
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ ub = nx.from_numpy(u)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M)
+
+ G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+
+ # check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+
+
+@pytest.skip_backend("jax")
+def test_sinkhorn2_variants_multi_b(nx):
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ ub = nx.from_numpy(u)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M)
+
+ G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+
+ # check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
+ np.testing.assert_allclose(G, Gl, atol=1e-05)
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
def test_sinkhorn_variants_log():
# test sinkhorn
- n = 100
+ n = 50
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -172,6 +252,7 @@ def test_sinkhorn_variants_log():
M = ot.dist(x, x)
G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
@@ -179,9 +260,30 @@ def test_sinkhorn_variants_log():
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Gl, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
- print(G0, G_green)
+
+
+def test_sinkhorn_variants_log_multib():
+ # test sinkhorn
+ n = 50
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ b = rng.rand(n, 3)
+ b = b / np.sum(b, 0, keepdims=True)
+
+ M = ot.dist(x, x)
+
+ G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
+ Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
+
+ # check values
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Gl, atol=1e-05)
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
@@ -326,10 +428,10 @@ def test_empirical_sinkhorn(nx):
a = ot.unif(n)
b = ot.unif(n)
- X_s = np.reshape(np.arange(n), (n, 1))
- X_t = np.reshape(np.arange(0, n), (n, 1))
+ X_s = np.reshape(1.0 * np.arange(n), (n, 1))
+ X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
M = ot.dist(X_s, X_t)
- M_m = ot.dist(X_s, X_t, metric='minkowski')
+ M_m = ot.dist(X_s, X_t, metric='euclidean')
ab = nx.from_numpy(a)
bb = nx.from_numpy(b)
@@ -346,7 +448,7 @@ def test_empirical_sinkhorn(nx):
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
- G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski'))
+ G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean'))
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
@@ -378,7 +480,7 @@ def test_lazy_empirical_sinkhorn(nx):
X_s = np.reshape(np.arange(n), (n, 1))
X_t = np.reshape(np.arange(0, n), (n, 1))
M = ot.dist(X_s, X_t)
- M_m = ot.dist(X_s, X_t, metric='minkowski')
+ M_m = ot.dist(X_s, X_t, metric='euclidean')
ab = nx.from_numpy(a)
bb = nx.from_numpy(b)
@@ -398,7 +500,7 @@ def test_lazy_empirical_sinkhorn(nx):
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
- f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
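The multi-histogram behaviour exercised by the tests above can be summarised as follows, a sketch consistent with `test_sinkhorn2_variants_multi_b` (toy data, not from the PR):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    x = rng.randn(10, 2)
    M = ot.dist(x, x)
    a = ot.utils.unif(10)
    B = rng.rand(10, 3)
    B = B / B.sum(axis=0, keepdims=True)  # three target histograms

    # with a matrix of targets, the solver loops over the columns and
    # returns one regularised OT loss per histogram, here shape (3,)
    losses = ot.sinkhorn2(a, B, M, reg=.1, method='sinkhorn_log')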
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 19d61b1..0242d72 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -180,8 +180,8 @@ def test_sampled_gromov():
def test_gromov_barycenter():
- ns = 50
- nt = 60
+ ns = 10
+ nt = 20
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
@@ -208,8 +208,8 @@ def test_gromov_barycenter():
@pytest.mark.filterwarnings("ignore:divide")
def test_gromov_entropic_barycenter():
- ns = 20
- nt = 30
+ ns = 10
+ nt = 20
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
@@ -222,7 +222,7 @@ def test_gromov_entropic_barycenter():
[ot.unif(ns), ot.unif(nt)
], ot.unif(n_samples), [.5, .5],
'square_loss', 1e-3,
- max_iter=50, tol=1e-5,
+ max_iter=50, tol=1e-3,
verbose=True)
np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
diff --git a/test/test_helpers.py b/test/test_helpers.py
index 8bd0015..cc4c90e 100644
--- a/test/test_helpers.py
+++ b/test/test_helpers.py
@@ -9,8 +9,8 @@ import sys
sys.path.append(os.path.join("ot", "helpers"))
-from openmp_helpers import get_openmp_flag, check_openmp_support # noqa
-from pre_build_helpers import _get_compiler, compile_test_program # noqa
+from openmp_helpers import get_openmp_flag, check_openmp_support # noqa
+from pre_build_helpers import _get_compiler, compile_test_program # noqa
def test_helpers():
diff --git a/test/test_utils.py b/test/test_utils.py
index 60ad5d3..0650ce2 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -7,6 +7,7 @@
import ot
import numpy as np
import sys
+import pytest
def test_proj_simplex(nx):
@@ -108,6 +109,10 @@ def test_dist():
D2 = ot.dist(x, x)
D3 = ot.dist(x)
+ D4 = ot.dist(x, x, metric='minkowski', p=0.5)
+
+ assert D4[0, 1] == D4[1, 0]
+
# dist should return squared euclidean
np.testing.assert_allclose(D, D2, atol=1e-14)
np.testing.assert_allclose(D, D3, atol=1e-14)
@@ -220,6 +225,13 @@ def test_deprecated_func():
class Class():
pass
+ with pytest.warns(DeprecationWarning):
+ fun()
+
+ with pytest.warns(DeprecationWarning):
+ cl = Class()
+ print(cl)
+
if sys.version_info < (3, 5):
print('Not tested')
else:
@@ -250,4 +262,7 @@ def test_BaseEstimator():
params['first'] = 'spam again'
cl.set_params(**params)
+ with pytest.raises(ValueError):
+ cl.set_params(bibi=10)
+
assert cl.first == 'spam again'