From 184f8f4f7ac78f1dd7f653496d2753211a4e3426 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 1 Jun 2021 10:10:54 +0200 Subject: [MRG] POT numpy/torch/jax backends (#249) * add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre 
Gramfort * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort * Update test/test_utils.py Co-authored-by: Alexandre Gramfort * Update ot/utils.py Co-authored-by: Alexandre Gramfort * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend Co-authored-by: Nicolas Courty Co-authored-by: Alexandre Gramfort --- ot/utils.py | 128 ++++++++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 107 insertions(+), 21 deletions(-) (limited to 'ot/utils.py') diff --git a/ot/utils.py b/ot/utils.py index 544c569..4dac0c5 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -16,6 +16,7 @@ from scipy.spatial.distance import cdist import sys import warnings from inspect import signature +from .backend import get_backend __time_tic_toc = time.time() @@ -41,8 +42,11 @@ def toq(): def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): """Compute kernel matrix""" + + nx = get_backend(x1, x2) + if method.lower() in ['gaussian', 'gauss', 'rbf']: - K = np.exp(-dist(x1, x2) / (2 * sigma**2)) + K = nx.exp(-dist(x1, x2) / (2 * sigma**2)) return K @@ -52,6 +56,66 @@ def laplacian(x): return L +def list_to_array(*lst): + """ Convert a list if in numpy format """ + if len(lst) > 1: + return [np.array(a) if isinstance(a, list) else a for a in lst] + else: + return np.array(lst[0]) if isinstance(lst[0], list) else lst[0] + + +def proj_simplex(v, z=1): + r""" compute the closest point (orthogonal projection) on the + generalized (n-1)-simplex of a vector v wrt. to the Euclidean + distance, thus solving: + .. math:: + \mathcal{P}(w) \in arg\min_\gamma || \gamma - v ||_2 + + s.t. 
\gamma^T 1= z + + \gamma\geq 0 + + If v is a 2d array, compute all the projections wrt. axis 0 + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + + Parameters + ---------- + v : {array-like}, shape (n, d) + z : int, optional + 'size' of the simplex (each vectors sum to z, 1 by default) + + Returns + ------- + h : ndarray, shape (n,d) + Array of projections on the simplex + """ + nx = get_backend(v) + n = v.shape[0] + if v.ndim == 1: + d1 = 1 + v = v[:, None] + else: + d1 = 0 + d = v.shape[1] + + # sort u in ascending order + u = nx.sort(v, axis=0) + # take the descending order + u = nx.flip(u, 0) + cssv = nx.cumsum(u, axis=0) - z + ind = nx.arange(n, type_as=v)[:, None] + 1 + cond = u - cssv / ind > 0 + rho = nx.sum(cond, 0) + theta = cssv[rho - 1, nx.arange(d)] / rho + w = nx.maximum(v - theta[None, :], nx.zeros(v.shape, type_as=v)) + if d1: + return w[:, 0] + else: + return w + + def unif(n): """ return a uniform histogram of length n (simplex) @@ -84,52 +148,68 @@ def euclidean_distances(X, Y, squared=False): """ Considering the rows of X (and Y=X) as vectors, compute the distance matrix between each pair of vectors. + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + Parameters ---------- X : {array-like}, shape (n_samples_1, n_features) Y : {array-like}, shape (n_samples_2, n_features) squared : boolean, optional Return squared Euclidean distances. 
+ Returns ------- distances : {array}, shape (n_samples_1, n_samples_2) """ - XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis] - YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :] - distances = np.dot(X, Y.T) - distances *= -2 - distances += XX - distances += YY - np.maximum(distances, 0, out=distances) + + nx = get_backend(X, Y) + + a2 = nx.einsum('ij,ij->i', X, X) + b2 = nx.einsum('ij,ij->i', Y, Y) + + c = -2 * nx.dot(X, Y.T) + c += a2[:, None] + c += b2[None, :] + + c = nx.maximum(c, 0) + + if not squared: + c = nx.sqrt(c) + if X is Y: - # Ensure that distances between vectors and themselves are set to 0.0. - # This may not be the case due to floating point rounding errors. - distances.flat[::distances.shape[0] + 1] = 0.0 - return distances if squared else np.sqrt(distances, out=distances) + c = c * (1 - nx.eye(X.shape[0], type_as=c)) + + return c def dist(x1, x2=None, metric='sqeuclidean'): - """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist + """Compute distance between samples in x1 and x2 + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Parameters ---------- - x1 : ndarray, shape (n1,d) + x1 : array-like, shape (n1,d) matrix with n1 samples of size d - x2 : array, shape (n2,d), optional + x2 : array-like, shape (n2,d), optional matrix with n2 samples of size d (if None then x2=x1) metric : str | callable, optional - Name of the metric to be computed (full list in the doc of scipy), If a string, - the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock', - 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', - 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', + 'sqeuclidean' or 'euclidean' on all backends. 
On numpy the function also + accepts from the scipy.spatial.distance.cdist function : 'braycurtis', + 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', + 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', + 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'. Returns ------- - M : np.array (n1,n2) + M : array-like, shape (n1, n2) distance matrix computed with given metric """ @@ -137,7 +217,13 @@ def dist(x1, x2=None, metric='sqeuclidean'): x2 = x1 if metric == "sqeuclidean": return euclidean_distances(x1, x2, squared=True) - return cdist(x1, x2, metric=metric) + elif metric == "euclidean": + return euclidean_distances(x1, x2, squared=False) + else: + if not get_backend(x1, x2).__name__ == 'numpy': + raise NotImplementedError() + else: + return cdist(x1, x2, metric=metric) def dist0(n, method='lin_square'): -- cgit v1.2.3