diff options
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 128 |
1 files changed, 107 insertions, 21 deletions
diff --git a/ot/utils.py b/ot/utils.py index 544c569..4dac0c5 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -16,6 +16,7 @@ from scipy.spatial.distance import cdist import sys import warnings from inspect import signature +from .backend import get_backend __time_tic_toc = time.time() @@ -41,8 +42,11 @@ def toq(): def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): """Compute kernel matrix""" + + nx = get_backend(x1, x2) + if method.lower() in ['gaussian', 'gauss', 'rbf']: - K = np.exp(-dist(x1, x2) / (2 * sigma**2)) + K = nx.exp(-dist(x1, x2) / (2 * sigma**2)) return K @@ -52,6 +56,66 @@ def laplacian(x): return L +def list_to_array(*lst): + """ Convert a list if in numpy format """ + if len(lst) > 1: + return [np.array(a) if isinstance(a, list) else a for a in lst] + else: + return np.array(lst[0]) if isinstance(lst[0], list) else lst[0] + + +def proj_simplex(v, z=1): + r""" compute the closest point (orthogonal projection) on the + generalized (n-1)-simplex of a vector v wrt. to the Euclidean + distance, thus solving: + .. math:: + \mathcal{P}(w) \in arg\min_\gamma || \gamma - v ||_2 + + s.t. \gamma^T 1= z + + \gamma\geq 0 + + If v is a 2d array, compute all the projections wrt. axis 0 + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + + Parameters + ---------- + v : {array-like}, shape (n, d) + z : int, optional + 'size' of the simplex (each vectors sum to z, 1 by default) + + Returns + ------- + h : ndarray, shape (n,d) + Array of projections on the simplex + """ + nx = get_backend(v) + n = v.shape[0] + if v.ndim == 1: + d1 = 1 + v = v[:, None] + else: + d1 = 0 + d = v.shape[1] + + # sort u in ascending order + u = nx.sort(v, axis=0) + # take the descending order + u = nx.flip(u, 0) + cssv = nx.cumsum(u, axis=0) - z + ind = nx.arange(n, type_as=v)[:, None] + 1 + cond = u - cssv / ind > 0 + rho = nx.sum(cond, 0) + theta = cssv[rho - 1, nx.arange(d)] / rho + w = nx.maximum(v - theta[None, :], nx.zeros(v.shape, type_as=v)) + if d1: + return w[:, 0] + else: + return w + + def unif(n): """ return a uniform histogram of length n (simplex) @@ -84,52 +148,68 @@ def euclidean_distances(X, Y, squared=False): """ Considering the rows of X (and Y=X) as vectors, compute the distance matrix between each pair of vectors. + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + Parameters ---------- X : {array-like}, shape (n_samples_1, n_features) Y : {array-like}, shape (n_samples_2, n_features) squared : boolean, optional Return squared Euclidean distances. + Returns ------- distances : {array}, shape (n_samples_1, n_samples_2) """ - XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis] - YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :] - distances = np.dot(X, Y.T) - distances *= -2 - distances += XX - distances += YY - np.maximum(distances, 0, out=distances) + + nx = get_backend(X, Y) + + a2 = nx.einsum('ij,ij->i', X, X) + b2 = nx.einsum('ij,ij->i', Y, Y) + + c = -2 * nx.dot(X, Y.T) + c += a2[:, None] + c += b2[None, :] + + c = nx.maximum(c, 0) + + if not squared: + c = nx.sqrt(c) + if X is Y: - # Ensure that distances between vectors and themselves are set to 0.0. - # This may not be the case due to floating point rounding errors. - distances.flat[::distances.shape[0] + 1] = 0.0 - return distances if squared else np.sqrt(distances, out=distances) + c = c * (1 - nx.eye(X.shape[0], type_as=c)) + + return c def dist(x1, x2=None, metric='sqeuclidean'): - """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist + """Compute distance between samples in x1 and x2 + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Parameters ---------- - x1 : ndarray, shape (n1,d) + x1 : array-like, shape (n1,d) matrix with n1 samples of size d - x2 : array, shape (n2,d), optional + x2 : array-like, shape (n2,d), optional matrix with n2 samples of size d (if None then x2=x1) metric : str | callable, optional - Name of the metric to be computed (full list in the doc of scipy), If a string, - the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock', - 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', - 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', + 'sqeuclidean' or 'euclidean' on all backends. On numpy the function also + accepts from the scipy.spatial.distance.cdist function : 'braycurtis', + 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', + 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', + 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'. Returns ------- - M : np.array (n1,n2) + M : array-like, shape (n1, n2) distance matrix computed with given metric """ @@ -137,7 +217,13 @@ def dist(x1, x2=None, metric='sqeuclidean'): x2 = x1 if metric == "sqeuclidean": return euclidean_distances(x1, x2, squared=True) - return cdist(x1, x2, metric=metric) + elif metric == "euclidean": + return euclidean_distances(x1, x2, squared=False) + else: + if not get_backend(x1, x2).__name__ == 'numpy': + raise NotImplementedError() + else: + return cdist(x1, x2, metric=metric) def dist0(n, method='lin_square'): |