diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2021-06-01 10:10:54 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-06-01 10:10:54 +0200 |
commit | 184f8f4f7ac78f1dd7f653496d2753211a4e3426 (patch) | |
tree | 483a7274c91030fd644de49b03a5fad04af9deba /ot/bregman.py | |
parent | 1f16614954e2522fbdb1598c5b1f5c3630c68472 (diff) |
[MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends
* stat sets on functions
* proper import
* install recent torch on windows
* install recent torch on windows
* now testing all functions in backedn
* add jax backedn
* clenaup windowds
* proper convert for jax backedn
* pep8
* try again windows tests
* test jax conversion
* try proper widows tests
* emd fuction ses backedn
* better test partial OT
* proper tests to_numpy and teplate Backend
* pep8
* pep8 x2
* feaking sinkhorn works with torch
* sinkhorn2 compatible
* working ot.emd2
* important detach
* it should work
* jax autodiff emd
* pep8
* no tast same for jax
* new independat tests per backedn
* freaking pep8
* add tests for gradients
* deprecate ot.gpu
* worging dist function
* working dist
* dist done in backedn
* not in
* remove indexing
* change accuacy for jax
* first pull backend
* projection simplex
* projection simplex
* projection simplex
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8
* add backedn discusion to quickstart guide
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8 + better doc
* proper links
* corect doctest
* big debug documentation
* doctest again
* doctest again bis
* doctest again ter (last one or i kill myself)
* backend test + doc proj simplex
* correction test_utils
* correction test_utils
* correction cumsum
* correction flip
* correction flip v2
* more debug
* more debug
* more debug + pep8
* pep8
* argh
* proj_simplex
* backedn works for sort
* proj simplex
* jax sucks
* update doc
* Update test/test_utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/readme.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update test/test_utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update ot/utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/readme.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update ot/lp/__init__.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* begin comment alex
* comment alex part 2
* optimize test gromov
* proj_simplex on vectors
* add awesome gradient decsnt example on the weights
* pep98 of course
* proof read example by alex
* pep8 again
* encoding oos in translation
* correct legend
Co-authored-by: Nicolas Courty <ncourty@irisa.fr>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 141 |
1 files changed, 71 insertions, 70 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 559db14..b10effd 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -19,7 +19,8 @@ import warnings import numpy as np from scipy.optimize import fmin_l_bfgs_b -from ot.utils import unif, dist +from ot.utils import unif, dist, list_to_array +from .backend import get_backend def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, @@ -43,17 +44,36 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are source and target weights (histograms, both sum to 1) - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in [2]_ + + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + errors. This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. This is why + :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value + of the regularization (and using warm start) sometimes leads to better + solutions. Note that the greedy version of the sinkhorn + :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a + fast approximation of the Sinkhorn problem. Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 @@ -69,25 +89,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, log : bool, optional record log if True - **Choosing a Sinkhorn solver** - - By default and when using a regularization parameter that is not too small - the default sinkhorn solver should be enough. If you need to use a small - regularization to get sharper OT matrices, you should use the - :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical - errors. This last solver can be very slow in practice and might not even - converge to a reasonable OT matrix in a finite time. This is why - :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value - of the regularization (and using warm start) sometimes leads to better - solutions. Note that the greedy version of the sinkhorn - :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening - version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. - - Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -166,17 +170,35 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are source and target weights (histograms, both sum to 1) + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + errors. This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. This is why + :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value + of the regularization (and using warm start) sometimes leads to better + solutions. Note that the greedy version of the sinkhorn + :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a + fast approximation of the Sinkhorn problem. + Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 @@ -191,28 +213,14 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, log : bool, optional record log if True - **Choosing a Sinkhorn solver** - - By default and when using a regularization parameter that is not too small - the default sinkhorn solver should be enough. If you need to use a small - regularization to get sharper OT matrices, you should use the - :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical - errors. This last solver can be very slow in practice and might not even - converge to a reasonable OT matrix in a finite time. This is why - :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value - of the regularization (and using warm start) sometimes leads to better - solutions. Note that the greedy version of the sinkhorn - :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening - version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. - Returns ------- - W : (n_hists) ndarray + W : (n_hists) float/array-like Optimal transportation loss for the given parameters log : dict log dictionary return only if log==True in parameters + Examples -------- @@ -247,7 +255,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] """ - b = np.asarray(b, dtype=np.float64) + + b = list_to_array(b) if len(b.shape) < 2: b = b[:, None] @@ -339,14 +348,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M) if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M) # init data dim_a = len(a) @@ -363,21 +372,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, n_hists)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b - # print(reg) - - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - - # print(np.min(K)) - tmp2 = np.empty(b.shape, dtype=M.dtype) + K = nx.exp(M / (-reg)) Kp = (1 / a).reshape(-1, 1) * K cpt = 0 @@ -386,13 +387,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, uprev = u vprev = v - KtransposeU = np.dot(K.T, u) - v = np.divide(b, KtransposeU) - u = 1. / np.dot(Kp, v) + KtransposeU = nx.dot(K.T, u) + v = b / KtransposeU + u = 1. / nx.dot(Kp, v) - if (np.any(KtransposeU == 0) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (nx.any(KtransposeU == 0) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop print('Warning: numerical errors at iteration', cpt) @@ -403,11 +404,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: - np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2) + tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v) else: # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 - np.einsum('i,ij,j->j', u, K, v, out=tmp2) - err = np.linalg.norm(tmp2 - b) # violation of marginal + tmp2 = nx.einsum('i,ij,j->j', u, K, v) + err = nx.norm(tmp2 - b) # violation of marginal if log: log['err'].append(err) @@ -422,7 +423,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, log['v'] = v if n_hists: # return only loss - res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) + res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: return res, log else: |