diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2021-06-01 10:10:54 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-06-01 10:10:54 +0200 |
commit | 184f8f4f7ac78f1dd7f653496d2753211a4e3426 (patch) | |
tree | 483a7274c91030fd644de49b03a5fad04af9deba /ot/utils.py | |
parent | 1f16614954e2522fbdb1598c5b1f5c3630c68472 (diff) |
[MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends
* stat sets on functions
* proper import
* install recent torch on windows
* install recent torch on windows
* now testing all functions in backedn
* add jax backedn
* clenaup windowds
* proper convert for jax backedn
* pep8
* try again windows tests
* test jax conversion
* try proper widows tests
* emd fuction ses backedn
* better test partial OT
* proper tests to_numpy and teplate Backend
* pep8
* pep8 x2
* feaking sinkhorn works with torch
* sinkhorn2 compatible
* working ot.emd2
* important detach
* it should work
* jax autodiff emd
* pep8
* no tast same for jax
* new independat tests per backedn
* freaking pep8
* add tests for gradients
* deprecate ot.gpu
* worging dist function
* working dist
* dist done in backedn
* not in
* remove indexing
* change accuacy for jax
* first pull backend
* projection simplex
* projection simplex
* projection simplex
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8
* add backedn discusion to quickstart guide
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8 + better doc
* proper links
* corect doctest
* big debug documentation
* doctest again
* doctest again bis
* doctest again ter (last one or i kill myself)
* backend test + doc proj simplex
* correction test_utils
* correction test_utils
* correction cumsum
* correction flip
* correction flip v2
* more debug
* more debug
* more debug + pep8
* pep8
* argh
* proj_simplex
* backedn works for sort
* proj simplex
* jax sucks
* update doc
* Update test/test_utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/readme.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update test/test_utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update ot/utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/readme.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update ot/lp/__init__.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* begin comment alex
* comment alex part 2
* optimize test gromov
* proj_simplex on vectors
* add awesome gradient decsnt example on the weights
* pep98 of course
* proof read example by alex
* pep8 again
* encoding oos in translation
* correct legend
Co-authored-by: Nicolas Courty <ncourty@irisa.fr>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 128 |
1 files changed, 107 insertions, 21 deletions
diff --git a/ot/utils.py b/ot/utils.py index 544c569..4dac0c5 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -16,6 +16,7 @@ from scipy.spatial.distance import cdist import sys import warnings from inspect import signature +from .backend import get_backend __time_tic_toc = time.time() @@ -41,8 +42,11 @@ def toq(): def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): """Compute kernel matrix""" + + nx = get_backend(x1, x2) + if method.lower() in ['gaussian', 'gauss', 'rbf']: - K = np.exp(-dist(x1, x2) / (2 * sigma**2)) + K = nx.exp(-dist(x1, x2) / (2 * sigma**2)) return K @@ -52,6 +56,66 @@ def laplacian(x): return L +def list_to_array(*lst): + """ Convert a list if in numpy format """ + if len(lst) > 1: + return [np.array(a) if isinstance(a, list) else a for a in lst] + else: + return np.array(lst[0]) if isinstance(lst[0], list) else lst[0] + + +def proj_simplex(v, z=1): + r""" compute the closest point (orthogonal projection) on the + generalized (n-1)-simplex of a vector v wrt. to the Euclidean + distance, thus solving: + .. math:: + \mathcal{P}(w) \in arg\min_\gamma || \gamma - v ||_2 + + s.t. \gamma^T 1= z + + \gamma\geq 0 + + If v is a 2d array, compute all the projections wrt. axis 0 + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + + Parameters + ---------- + v : {array-like}, shape (n, d) + z : int, optional + 'size' of the simplex (each vectors sum to z, 1 by default) + + Returns + ------- + h : ndarray, shape (n,d) + Array of projections on the simplex + """ + nx = get_backend(v) + n = v.shape[0] + if v.ndim == 1: + d1 = 1 + v = v[:, None] + else: + d1 = 0 + d = v.shape[1] + + # sort u in ascending order + u = nx.sort(v, axis=0) + # take the descending order + u = nx.flip(u, 0) + cssv = nx.cumsum(u, axis=0) - z + ind = nx.arange(n, type_as=v)[:, None] + 1 + cond = u - cssv / ind > 0 + rho = nx.sum(cond, 0) + theta = cssv[rho - 1, nx.arange(d)] / rho + w = nx.maximum(v - theta[None, :], nx.zeros(v.shape, type_as=v)) + if d1: + return w[:, 0] + else: + return w + + def unif(n): """ return a uniform histogram of length n (simplex) @@ -84,52 +148,68 @@ def euclidean_distances(X, Y, squared=False): """ Considering the rows of X (and Y=X) as vectors, compute the distance matrix between each pair of vectors. + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + Parameters ---------- X : {array-like}, shape (n_samples_1, n_features) Y : {array-like}, shape (n_samples_2, n_features) squared : boolean, optional Return squared Euclidean distances. + Returns ------- distances : {array}, shape (n_samples_1, n_samples_2) """ - XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis] - YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :] - distances = np.dot(X, Y.T) - distances *= -2 - distances += XX - distances += YY - np.maximum(distances, 0, out=distances) + + nx = get_backend(X, Y) + + a2 = nx.einsum('ij,ij->i', X, X) + b2 = nx.einsum('ij,ij->i', Y, Y) + + c = -2 * nx.dot(X, Y.T) + c += a2[:, None] + c += b2[None, :] + + c = nx.maximum(c, 0) + + if not squared: + c = nx.sqrt(c) + if X is Y: - # Ensure that distances between vectors and themselves are set to 0.0. - # This may not be the case due to floating point rounding errors. - distances.flat[::distances.shape[0] + 1] = 0.0 - return distances if squared else np.sqrt(distances, out=distances) + c = c * (1 - nx.eye(X.shape[0], type_as=c)) + + return c def dist(x1, x2=None, metric='sqeuclidean'): - """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist + """Compute distance between samples in x1 and x2 + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Parameters ---------- - x1 : ndarray, shape (n1,d) + x1 : array-like, shape (n1,d) matrix with n1 samples of size d - x2 : array, shape (n2,d), optional + x2 : array-like, shape (n2,d), optional matrix with n2 samples of size d (if None then x2=x1) metric : str | callable, optional - Name of the metric to be computed (full list in the doc of scipy), If a string, - the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock', - 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', - 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', + 'sqeuclidean' or 'euclidean' on all backends. On numpy the function also + accepts from the scipy.spatial.distance.cdist function : 'braycurtis', + 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', + 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', + 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'. Returns ------- - M : np.array (n1,n2) + M : array-like, shape (n1, n2) distance matrix computed with given metric """ @@ -137,7 +217,13 @@ def dist(x1, x2=None, metric='sqeuclidean'): x2 = x1 if metric == "sqeuclidean": return euclidean_distances(x1, x2, squared=True) - return cdist(x1, x2, metric=metric) + elif metric == "euclidean": + return euclidean_distances(x1, x2, squared=False) + else: + if not get_backend(x1, x2).__name__ == 'numpy': + raise NotImplementedError() + else: + return cdist(x1, x2, metric=metric) def dist0(n, method='lin_square'): |