diff options
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 498 |
1 files changed, 498 insertions, 0 deletions
diff --git a/ot/utils.py b/ot/utils.py new file mode 100644 index 0000000..b71458b --- /dev/null +++ b/ot/utils.py @@ -0,0 +1,498 @@ +# -*- coding: utf-8 -*- +""" +Various useful functions +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +import multiprocessing +from functools import reduce +import time + +import numpy as np +from scipy.spatial.distance import cdist +import sys +import warnings +try: + from inspect import signature +except ImportError: + from .externals.funcsigs import signature + +__time_tic_toc = time.time() + + +def tic(): + """ Python implementation of Matlab tic() function """ + global __time_tic_toc + __time_tic_toc = time.time() + + +def toc(message='Elapsed time : {} s'): + """ Python implementation of Matlab toc() function """ + t = time.time() + print(message.format(t - __time_tic_toc)) + return t - __time_tic_toc + + +def toq(): + """ Python implementation of Julia toc() function """ + t = time.time() + return t - __time_tic_toc + + +def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): + """Compute kernel matrix""" + if method.lower() in ['gaussian', 'gauss', 'rbf']: + K = np.exp(-dist(x1, x2) / (2 * sigma**2)) + return K + + +def unif(n): + """ return a uniform histogram of length n (simplex) + + Parameters + ---------- + + n : int + number of bins in the histogram + + Returns + ------- + h : np.array (n,) + histogram of length n such that h_i=1/n for all i + + + """ + return np.ones((n,)) / n + + +def clean_zeros(a, b, M): + """ Remove all components with zeros weights in a and b + """ + M2 = M[a > 0, :][:, b > 0].copy() # copy force c style matrix (froemd) + a2 = a[a > 0] + b2 = b[b > 0] + return a2, b2, M2 + + +def euclidean_distances(X, Y, squared=False): + """ + Considering the rows of X (and Y=X) as vectors, compute the + distance matrix between each pair of vectors. + Parameters + ---------- + X : {array-like}, shape (n_samples_1, n_features) + Y : {array-like}, shape (n_samples_2, n_features) + squared : boolean, optional + Return squared Euclidean distances. + Returns + ------- + distances : {array}, shape (n_samples_1, n_samples_2) + """ + XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis] + YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :] + distances = np.dot(X, Y.T) + distances *= -2 + distances += XX + distances += YY + np.maximum(distances, 0, out=distances) + if X is Y: + # Ensure that distances between vectors and themselves are set to 0.0. + # This may not be the case due to floating point rounding errors. + distances.flat[::distances.shape[0] + 1] = 0.0 + return distances if squared else np.sqrt(distances, out=distances) + + +def dist(x1, x2=None, metric='sqeuclidean'): + """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist + + Parameters + ---------- + + x1 : ndarray, shape (n1,d) + matrix with n1 samples of size d + x2 : array, shape (n2,d), optional + matrix with n2 samples of size d (if None then x2=x1) + metric : str | callable, optional + Name of the metric to be computed (full list in the doc of scipy), If a string, + the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock', + 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', + 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', + 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'. + + + Returns + ------- + + M : np.array (n1,n2) + distance matrix computed with given metric + + """ + if x2 is None: + x2 = x1 + if metric == "sqeuclidean": + return euclidean_distances(x1, x2, squared=True) + return cdist(x1, x2, metric=metric) + + +def dist0(n, method='lin_square'): + """Compute standard cost matrices of size (n, n) for OT problems + + Parameters + ---------- + n : int + Size of the cost matrix. + method : str, optional + Type of loss matrix chosen from: + + * 'lin_square' : linear sampling between 0 and n-1, quadratic loss + + Returns + ------- + M : ndarray, shape (n1,n2) + Distance matrix computed with given metric. + """ + res = 0 + if method == 'lin_square': + x = np.arange(n, dtype=np.float64).reshape((n, 1)) + res = dist(x, x) + return res + + +def cost_normalization(C, norm=None): + """ Apply normalization to the loss matrix + + Parameters + ---------- + C : ndarray, shape (n1, n2) + The cost matrix to normalize. + norm : str + Type of normalization from 'median', 'max', 'log', 'loglog'. Any + other value do not normalize. + + Returns + ------- + C : ndarray, shape (n1, n2) + The input cost matrix normalized according to given norm. + """ + + if norm is None: + pass + elif norm == "median": + C /= float(np.median(C)) + elif norm == "max": + C /= float(np.max(C)) + elif norm == "log": + C = np.log(1 + C) + elif norm == "loglog": + C = np.log1p(np.log1p(C)) + else: + raise ValueError('Norm %s is not a valid option.\n' + 'Valid options are:\n' + 'median, max, log, loglog' % norm) + return C + + +def dots(*args): + """ dots function for multiple matrix multiply """ + return reduce(np.dot, args) + + +def fun(f, q_in, q_out): + """ Utility function for parmap with no serializing problems """ + while True: + i, x = q_in.get() + if i is None: + break + q_out.put((i, f(x))) + + +def parmap(f, X, nprocs=multiprocessing.cpu_count()): + """ paralell map for multiprocessing (only map on windows)""" + + if not sys.platform.endswith('win32'): + + q_in = multiprocessing.Queue(1) + q_out = multiprocessing.Queue() + + proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out)) + for _ in range(nprocs)] + for p in proc: + p.daemon = True + p.start() + + sent = [q_in.put((i, x)) for i, x in enumerate(X)] + [q_in.put((None, None)) for _ in range(nprocs)] + res = [q_out.get() for _ in range(len(sent))] + + [p.join() for p in proc] + + return [x for i, x in sorted(res)] + else: + return list(map(f, X)) + + +def check_params(**kwargs): + """check_params: check whether some parameters are missing + """ + + missing_params = [] + check = True + + for param in kwargs: + if kwargs[param] is None: + missing_params.append(param) + + if len(missing_params) > 0: + print("POT - Warning: following necessary parameters are missing") + for p in missing_params: + print("\n", p) + + check = False + + return check + + +def check_random_state(seed): + """Turn seed into a np.random.RandomState instance + + Parameters + ---------- + seed : None | int | instance of RandomState + If seed is None, return the RandomState singleton used by np.random. + If seed is an int, return a new RandomState instance seeded with seed. + If seed is already a RandomState instance, return it. + Otherwise raise ValueError. + """ + if seed is None or seed is np.random: + return np.random.mtrand._rand + if isinstance(seed, (int, np.integer)): + return np.random.RandomState(seed) + if isinstance(seed, np.random.RandomState): + return seed + raise ValueError('{} cannot be used to seed a numpy.random.RandomState' + ' instance'.format(seed)) + + +class deprecated(object): + """Decorator to mark a function or class as deprecated. + + deprecated class from scikit-learn package + https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py + Issue a warning when the function is called/the class is instantiated and + adds a warning to the docstring. + The optional extra argument will be appended to the deprecation message + and the docstring. Note: to use this with the default value for extra, put + in an empty of parentheses: + >>> from ot.deprecation import deprecated # doctest: +SKIP + >>> @deprecated() # doctest: +SKIP + ... def some_function(): pass # doctest: +SKIP + + Parameters + ---------- + extra : str + To be added to the deprecation messages. + """ + + # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary, + # but with many changes. + + def __init__(self, extra=''): + self.extra = extra + + def __call__(self, obj): + """Call method + Parameters + ---------- + obj : object + """ + if isinstance(obj, type): + return self._decorate_class(obj) + else: + return self._decorate_fun(obj) + + def _decorate_class(self, cls): + msg = "Class %s is deprecated" % cls.__name__ + if self.extra: + msg += "; %s" % self.extra + + # FIXME: we should probably reset __new__ for full generality + init = cls.__init__ + + def wrapped(*args, **kwargs): + warnings.warn(msg, category=DeprecationWarning) + return init(*args, **kwargs) + + cls.__init__ = wrapped + + wrapped.__name__ = '__init__' + wrapped.__doc__ = self._update_doc(init.__doc__) + wrapped.deprecated_original = init + + return cls + + def _decorate_fun(self, fun): + """Decorate function fun""" + + msg = "Function %s is deprecated" % fun.__name__ + if self.extra: + msg += "; %s" % self.extra + + def wrapped(*args, **kwargs): + warnings.warn(msg, category=DeprecationWarning) + return fun(*args, **kwargs) + + wrapped.__name__ = fun.__name__ + wrapped.__dict__ = fun.__dict__ + wrapped.__doc__ = self._update_doc(fun.__doc__) + + return wrapped + + def _update_doc(self, olddoc): + newdoc = "DEPRECATED" + if self.extra: + newdoc = "%s: %s" % (newdoc, self.extra) + if olddoc: + newdoc = "%s\n\n%s" % (newdoc, olddoc) + return newdoc + + +def _is_deprecated(func): + """Helper to check if func is wraped by our deprecated decorator""" + if sys.version_info < (3, 5): + raise NotImplementedError("This is only available for python3.5 " + "or above") + closures = getattr(func, '__closure__', []) + if closures is None: + closures = [] + is_deprecated = ('deprecated' in ''.join([c.cell_contents + for c in closures + if isinstance(c.cell_contents, str)])) + return is_deprecated + + +class BaseEstimator(object): + """Base class for most objects in POT + + Code adapted from sklearn BaseEstimator class + + Notes + ----- + All estimators should specify all the parameters that can be set + at the class level in their ``__init__`` as explicit keyword + arguments (no ``*args`` or ``**kwargs``). + """ + + @classmethod + def _get_param_names(cls): + """Get parameter names for the estimator""" + + # fetch the constructor or the original constructor before + # deprecation wrapping if any + init = getattr(cls.__init__, 'deprecated_original', cls.__init__) + if init is object.__init__: + # No explicit constructor to introspect + return [] + + # introspect the constructor arguments to find the model parameters + # to represent + init_signature = signature(init) + # Consider the constructor parameters excluding 'self' + parameters = [p for p in init_signature.parameters.values() + if p.name != 'self' and p.kind != p.VAR_KEYWORD] + for p in parameters: + if p.kind == p.VAR_POSITIONAL: + raise RuntimeError("POT estimators should always " + "specify their parameters in the signature" + " of their __init__ (no varargs)." + " %s with constructor %s doesn't " + " follow this convention." + % (cls, init_signature)) + # Extract and sort argument names excluding 'self' + return sorted([p.name for p in parameters]) + + def get_params(self, deep=True): + """Get parameters for this estimator. + + Parameters + ---------- + deep : bool, optional + If True, will return the parameters for this estimator and + contained subobjects that are estimators. + + Returns + ------- + params : mapping of string to any + Parameter names mapped to their values. + """ + out = dict() + for key in self._get_param_names(): + # We need deprecation warnings to always be on in order to + # catch deprecated param values. + # This is set in utils/__init__.py but it gets overwritten + # when running under python3 somehow. + warnings.simplefilter("always", DeprecationWarning) + try: + with warnings.catch_warnings(record=True) as w: + value = getattr(self, key, None) + if len(w) and w[0].category == DeprecationWarning: + # if the parameter is deprecated, don't show it + continue + finally: + warnings.filters.pop(0) + + # XXX: should we rather test if instance of estimator? + if deep and hasattr(value, 'get_params'): + deep_items = value.get_params().items() + out.update((key + '__' + k, val) for k, val in deep_items) + out[key] = value + return out + + def set_params(self, **params): + """Set the parameters of this estimator. + + The method works on simple estimators as well as on nested objects + (such as pipelines). The latter have parameters of the form + ``<component>__<parameter>`` so that it's possible to update each + component of a nested object. + + Returns + ------- + self + """ + if not params: + # Simple optimisation to gain speed (inspect is slow) + return self + valid_params = self.get_params(deep=True) + # for key, value in iteritems(params): + for key, value in params.items(): + split = key.split('__', 1) + if len(split) > 1: + # nested objects case + name, sub_name = split + if name not in valid_params: + raise ValueError('Invalid parameter %s for estimator %s. ' + 'Check the list of available parameters ' + 'with `estimator.get_params().keys()`.' % + (name, self)) + sub_object = valid_params[name] + sub_object.set_params(**{sub_name: value}) + else: + # simple objects case + if key not in valid_params: + raise ValueError('Invalid parameter %s for estimator %s. ' + 'Check the list of available parameters ' + 'with `estimator.get_params().keys()`.' % + (key, self.__class__.__name__)) + setattr(self, key, value) + return self + + +class UndefinedParameter(Exception): + """ + Aim at raising an Exception when a undefined parameter is called + + """ + pass |