diff options
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 269 |
1 files changed, 163 insertions, 106 deletions
diff --git a/ot/utils.py b/ot/utils.py index f9911a1..c878563 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -7,7 +7,6 @@ Various useful functions # # License: MIT License -import multiprocessing from functools import reduce import time @@ -15,67 +14,127 @@ import numpy as np from scipy.spatial.distance import cdist import sys import warnings -try: - from inspect import signature -except ImportError: - from .externals.funcsigs import signature +from inspect import signature +from .backend import get_backend __time_tic_toc = time.time() def tic(): - """ Python implementation of Matlab tic() function """ + r""" Python implementation of Matlab tic() function """ global __time_tic_toc __time_tic_toc = time.time() def toc(message='Elapsed time : {} s'): - """ Python implementation of Matlab toc() function """ + r""" Python implementation of Matlab toc() function """ t = time.time() print(message.format(t - __time_tic_toc)) return t - __time_tic_toc def toq(): - """ Python implementation of Julia toc() function """ + r""" Python implementation of Julia toc() function """ t = time.time() return t - __time_tic_toc def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): - """Compute kernel matrix""" + r"""Compute kernel matrix""" + + nx = get_backend(x1, x2) + if method.lower() in ['gaussian', 'gauss', 'rbf']: - K = np.exp(-dist(x1, x2) / (2 * sigma**2)) + K = nx.exp(-dist(x1, x2) / (2 * sigma**2)) return K def laplacian(x): - """Compute Laplacian matrix""" + r"""Compute Laplacian matrix""" L = np.diag(np.sum(x, axis=0)) - x return L -def unif(n): - """ return a uniform histogram of length n (simplex) +def list_to_array(*lst): + r""" Convert a list if in numpy format """ + if len(lst) > 1: + return [np.array(a) if isinstance(a, list) else a for a in lst] + else: + return np.array(lst[0]) if isinstance(lst[0], list) else lst[0] + + +def proj_simplex(v, z=1): + r"""Compute the closest point (orthogonal projection) on the + generalized `(n-1)`-simplex of a vector :math:`\mathbf{v}` wrt. to the Euclidean + distance, thus solving: + + .. math:: + \mathcal{P}(w) \in \mathop{\arg \min}_\gamma \| \gamma - \mathbf{v} \|_2 + + s.t. \ \gamma^T \mathbf{1} = z + + \gamma \geq 0 + + If :math:`\mathbf{v}` is a 2d array, compute all the projections wrt. axis 0 + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Parameters ---------- + v : {array-like}, shape (n, d) + z : int, optional + 'size' of the simplex (each vectors sum to z, 1 by default) + + Returns + ------- + h : ndarray, shape (`n`, `d`) + Array of projections on the simplex + """ + nx = get_backend(v) + n = v.shape[0] + if v.ndim == 1: + d1 = 1 + v = v[:, None] + else: + d1 = 0 + d = v.shape[1] + + # sort u in ascending order + u = nx.sort(v, axis=0) + # take the descending order + u = nx.flip(u, 0) + cssv = nx.cumsum(u, axis=0) - z + ind = nx.arange(n, type_as=v)[:, None] + 1 + cond = u - cssv / ind > 0 + rho = nx.sum(cond, 0) + theta = cssv[rho - 1, nx.arange(d)] / rho + w = nx.maximum(v - theta[None, :], nx.zeros(v.shape, type_as=v)) + if d1: + return w[:, 0] + else: + return w + + +def unif(n): + r""" + Return a uniform histogram of length `n` (simplex). + Parameters + ---------- n : int number of bins in the histogram Returns ------- - h : np.array (n,) - histogram of length n such that h_i=1/n for all i - - + h : np.array (`n`,) + histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}` """ return np.ones((n,)) / n def clean_zeros(a, b, M): - """ Remove all components with zeros weights in a and b + r""" Remove all components with zeros weights in :math:`\mathbf{a}` and :math:`\mathbf{b}` """ M2 = M[a > 0, :][:, b > 0].copy() # copy force c style matrix (froemd) a2 = a[a > 0] @@ -84,55 +143,71 @@ def clean_zeros(a, b, M): def euclidean_distances(X, Y, squared=False): - """ - Considering the rows of X (and Y=X) as vectors, compute the + r""" + Considering the rows of :math:`\mathbf{X}` (and :math:`\mathbf{Y} = \mathbf{X}`) as vectors, compute the distance matrix between each pair of vectors. + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + Parameters ---------- - X : {array-like}, shape (n_samples_1, n_features) - Y : {array-like}, shape (n_samples_2, n_features) + X : array-like, shape (n_samples_1, n_features) + Y : array-like, shape (n_samples_2, n_features) squared : boolean, optional Return squared Euclidean distances. + Returns ------- - distances : {array}, shape (n_samples_1, n_samples_2) + distances : array-like, shape (`n_samples_1`, `n_samples_2`) """ - XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis] - YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :] - distances = np.dot(X, Y.T) - distances *= -2 - distances += XX - distances += YY - np.maximum(distances, 0, out=distances) + + nx = get_backend(X, Y) + + a2 = nx.einsum('ij,ij->i', X, X) + b2 = nx.einsum('ij,ij->i', Y, Y) + + c = -2 * nx.dot(X, Y.T) + c += a2[:, None] + c += b2[None, :] + + c = nx.maximum(c, 0) + + if not squared: + c = nx.sqrt(c) + if X is Y: - # Ensure that distances between vectors and themselves are set to 0.0. - # This may not be the case due to floating point rounding errors. - distances.flat[::distances.shape[0] + 1] = 0.0 - return distances if squared else np.sqrt(distances, out=distances) + c = c * (1 - nx.eye(X.shape[0], type_as=c)) + + return c + +def dist(x1, x2=None, metric='sqeuclidean', p=2): + r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` -def dist(x1, x2=None, metric='sqeuclidean'): - """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Parameters ---------- - x1 : ndarray, shape (n1,d) - matrix with n1 samples of size d - x2 : array, shape (n2,d), optional - matrix with n2 samples of size d (if None then x2=x1) + x1 : array-like, shape (n1,d) + matrix with `n1` samples of size `d` + x2 : array-like, shape (n2,d), optional + matrix with `n2` samples of size `d` (if None then :math:`\mathbf{x_2} = \mathbf{x_1}`) metric : str | callable, optional - Name of the metric to be computed (full list in the doc of scipy), If a string, - the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock', - 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', - 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', + 'sqeuclidean' or 'euclidean' on all backends. On numpy the function also + accepts from the scipy.spatial.distance.cdist function : 'braycurtis', + 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', + 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', + 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'. Returns ------- - M : np.array (n1,n2) + M : array-like, shape (`n1`, `n2`) distance matrix computed with given metric """ @@ -140,11 +215,17 @@ def dist(x1, x2=None, metric='sqeuclidean'): x2 = x1 if metric == "sqeuclidean": return euclidean_distances(x1, x2, squared=True) - return cdist(x1, x2, metric=metric) + elif metric == "euclidean": + return euclidean_distances(x1, x2, squared=False) + else: + if not get_backend(x1, x2).__name__ == 'numpy': + raise NotImplementedError() + else: + return cdist(x1, x2, metric=metric, p=p) def dist0(n, method='lin_square'): - """Compute standard cost matrices of size (n, n) for OT problems + r"""Compute standard cost matrices of size (`n`, `n`) for OT problems Parameters ---------- @@ -153,11 +234,11 @@ def dist0(n, method='lin_square'): method : str, optional Type of loss matrix chosen from: - * 'lin_square' : linear sampling between 0 and n-1, quadratic loss + * 'lin_square' : linear sampling between 0 and `n-1`, quadratic loss Returns ------- - M : ndarray, shape (n1,n2) + M : ndarray, shape (`n1`, `n2`) Distance matrix computed with given metric. """ res = 0 @@ -168,7 +249,7 @@ def dist0(n, method='lin_square'): def cost_normalization(C, norm=None): - """ Apply normalization to the loss matrix + r""" Apply normalization to the loss matrix Parameters ---------- @@ -180,7 +261,7 @@ def cost_normalization(C, norm=None): Returns ------- - C : ndarray, shape (n1, n2) + C : ndarray, shape (`n1`, `n2`) The input cost matrix normalized according to given norm. """ @@ -202,23 +283,23 @@ def cost_normalization(C, norm=None): def dots(*args): - """ dots function for multiple matrix multiply """ + r""" dots function for multiple matrix multiply """ return reduce(np.dot, args) def label_normalization(y, start=0): - """ Transform labels to start at a given value + r""" Transform labels to start at a given value Parameters ---------- y : array-like, shape (n, ) The vector of labels to be normalized. start : int - Desired value for the smallest label in y (default=0) + Desired value for the smallest label in :math:`\mathbf{y}` (default=0) Returns ------- - y : array-like, shape (n1, ) + y : array-like, shape (`n1`, ) The input vector of labels normalized according to given start value. """ @@ -228,42 +309,15 @@ def label_normalization(y, start=0): return y -def fun(f, q_in, q_out): - """ Utility function for parmap with no serializing problems """ - while True: - i, x = q_in.get() - if i is None: - break - q_out.put((i, f(x))) - - -def parmap(f, X, nprocs=multiprocessing.cpu_count()): - """ paralell map for multiprocessing (only map on windows)""" - - if not sys.platform.endswith('win32'): - - q_in = multiprocessing.Queue(1) - q_out = multiprocessing.Queue() - - proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out)) - for _ in range(nprocs)] - for p in proc: - p.daemon = True - p.start() - - sent = [q_in.put((i, x)) for i, x in enumerate(X)] - [q_in.put((None, None)) for _ in range(nprocs)] - res = [q_out.get() for _ in range(len(sent))] - - [p.join() for p in proc] - - return [x for i, x in sorted(res)] - else: - return list(map(f, X)) +def parmap(f, X, nprocs="default"): + r""" parallel map for multiprocessing. + The function has been deprecated and only performs a regular map. + """ + return list(map(f, X)) def check_params(**kwargs): - """check_params: check whether some parameters are missing + r"""check_params: check whether some parameters are missing """ missing_params = [] @@ -284,14 +338,14 @@ def check_params(**kwargs): def check_random_state(seed): - """Turn seed into a np.random.RandomState instance + r"""Turn `seed` into a np.random.RandomState instance Parameters ---------- seed : None | int | instance of RandomState - If seed is None, return the RandomState singleton used by np.random. - If seed is an int, return a new RandomState instance seeded with seed. - If seed is already a RandomState instance, return it. + If `seed` is None, return the RandomState singleton used by np.random. + If `seed` is an int, return a new RandomState instance seeded with `seed`. + If `seed` is already a RandomState instance, return it. Otherwise raise ValueError. """ if seed is None or seed is np.random: @@ -305,18 +359,21 @@ def check_random_state(seed): class deprecated(object): - """Decorator to mark a function or class as deprecated. + r"""Decorator to mark a function or class as deprecated. deprecated class from scikit-learn package https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py Issue a warning when the function is called/the class is instantiated and adds a warning to the docstring. The optional extra argument will be appended to the deprecation message - and the docstring. Note: to use this with the default value for extra, put - in an empty of parentheses: - >>> from ot.deprecation import deprecated # doctest: +SKIP - >>> @deprecated() # doctest: +SKIP - ... def some_function(): pass # doctest: +SKIP + and the docstring. + + .. note:: + To use this with the default value for extra, use empty parentheses: + + >>> from ot.deprecation import deprecated # doctest: +SKIP + >>> @deprecated() # doctest: +SKIP + ... def some_function(): pass # doctest: +SKIP Parameters ---------- @@ -331,7 +388,7 @@ class deprecated(object): self.extra = extra def __call__(self, obj): - """Call method + r"""Call method Parameters ---------- obj : object @@ -362,7 +419,7 @@ class deprecated(object): return cls def _decorate_fun(self, fun): - """Decorate function fun""" + r"""Decorate function fun""" msg = "Function %s is deprecated" % fun.__name__ if self.extra: @@ -388,7 +445,7 @@ class deprecated(object): def _is_deprecated(func): - """Helper to check if func is wraped by our deprecated decorator""" + r"""Helper to check if func is wraped by our deprecated decorator""" if sys.version_info < (3, 5): raise NotImplementedError("This is only available for python3.5 " "or above") @@ -402,7 +459,7 @@ def _is_deprecated(func): class BaseEstimator(object): - """Base class for most objects in POT + r"""Base class for most objects in POT Code adapted from sklearn BaseEstimator class @@ -415,7 +472,7 @@ class BaseEstimator(object): @classmethod def _get_param_names(cls): - """Get parameter names for the estimator""" + r"""Get parameter names for the estimator""" # fetch the constructor or the original constructor before # deprecation wrapping if any @@ -442,7 +499,7 @@ class BaseEstimator(object): return sorted([p.name for p in parameters]) def get_params(self, deep=True): - """Get parameters for this estimator. + r"""Get parameters for this estimator. Parameters ---------- @@ -479,7 +536,7 @@ class BaseEstimator(object): return out def set_params(self, **params): - """Set the parameters of this estimator. + r"""Set the parameters of this estimator. The method works on simple estimators as well as on nested objects (such as pipelines). The latter have parameters of the form @@ -519,7 +576,7 @@ class BaseEstimator(object): class UndefinedParameter(Exception): - """ + r""" Aim at raising an Exception when a undefined parameter is called """ |