# -*- coding: utf-8 -*-
"""
Various useful functions
"""

# Author: Remi Flamary
#
# License: MIT License

import multiprocessing
from functools import reduce
import time

import numpy as np
from scipy.spatial.distance import cdist
import sys
import warnings
try:
    from inspect import signature
except ImportError:
    from .externals.funcsigs import signature

# Reference timestamp shared by tic(), toc() and toq().
__time_tic_toc = time.time()


def tic():
    """Python implementation of Matlab tic() function.

    Resets the module-level reference time used by :func:`toc` and
    :func:`toq`.
    """
    global __time_tic_toc
    __time_tic_toc = time.time()


def toc(message='Elapsed time : {} s'):
    """Python implementation of Matlab toc() function.

    Parameters
    ----------
    message : str, optional
        Format string printed with the elapsed time as its only argument.

    Returns
    -------
    elapsed : float
        Seconds elapsed since the last call to :func:`tic` (or since the
        module was imported if :func:`tic` was never called).
    """
    t = time.time()
    print(message.format(t - __time_tic_toc))
    return t - __time_tic_toc


def toq():
    """Python implementation of Julia toc() function.

    Same as :func:`toc` but does not print anything.

    Returns
    -------
    elapsed : float
        Seconds elapsed since the last call to :func:`tic`.
    """
    t = time.time()
    return t - __time_tic_toc


def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
    """Compute kernel matrix

    Parameters
    ----------
    x1 : ndarray, shape (n1, d)
        First set of samples.
    x2 : ndarray, shape (n2, d)
        Second set of samples.
    method : str, optional
        Kernel name (case insensitive). Only the gaussian kernel is
        implemented, under the aliases 'gaussian', 'gauss' and 'rbf'.
    sigma : float, optional
        Bandwidth of the gaussian kernel.
    **kwargs
        Accepted for interface compatibility; currently unused.

    Returns
    -------
    K : ndarray, shape (n1, n2)
        Kernel matrix ``exp(-dist(x1, x2) / (2 * sigma**2))`` where ``dist``
        is the squared euclidean distance.

    Raises
    ------
    ValueError
        If ``method`` is not a supported kernel name.
    """
    if method.lower() in ['gaussian', 'gauss', 'rbf']:
        return np.exp(-dist(x1, x2) / (2 * sigma**2))
    # Previously an unknown method fell through to ``return K`` with K
    # unbound, which surfaced as a confusing UnboundLocalError.
    raise ValueError("Unknown kernel method '%s'. Valid options are: "
                     "gaussian, gauss, rbf" % method)


def unif(n):
    """ return a uniform histogram of length n (simplex)

    Parameters
    ----------

    n : int
        number of bins in the histogram

    Returns
    -------
    h : np.array (n,)
        histogram of length n such that h_i=1/n for all i

    """
    return np.ones((n,)) / n


def clean_zeros(a, b, M):
    """ Remove all components with zeros weights in a and b

    Parameters
    ----------
    a : ndarray, shape (n1,)
        Weights indexing the rows of ``M``.
    b : ndarray, shape (n2,)
        Weights indexing the columns of ``M``.
    M : ndarray, shape (n1, n2)
        Matrix to filter.

    Returns
    -------
    a2, b2, M2 : ndarray
        ``a`` and ``b`` restricted to their strictly positive entries and
        the corresponding sub-matrix of ``M``.
    """
    M2 = M[a > 0, :][:, b > 0].copy()  # copy force c style matrix (froemd)
    a2 = a[a > 0]
    b2 = b[b > 0]
    return a2, b2, M2


def euclidean_distances(X, Y, squared=False):
    """
    Considering the rows of X (and Y=X) as vectors, compute the
    distance matrix between each pair of vectors.

    Parameters
    ----------
    X : {array-like}, shape (n_samples_1, n_features)
    Y : {array-like}, shape (n_samples_2, n_features)
    squared : boolean, optional
        Return squared Euclidean distances.

    Returns
    -------
    distances : {array}, shape (n_samples_1, n_samples_2)
    """
    # Expansion ||x - y||^2 = ||x||^2 - 2 <x, y> + ||y||^2
    XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis]
    YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :]
    distances = np.dot(X, Y.T)
    distances *= -2
    distances += XX
    distances += YY
    # Rounding errors can produce tiny negative values; clip them.
    np.maximum(distances, 0, out=distances)
    if X is Y:
        # Ensure that distances between vectors and themselves are set to 0.0.
        # This may not be the case due to floating point rounding errors.
        distances.flat[::distances.shape[0] + 1] = 0.0
    return distances if squared else np.sqrt(distances, out=distances)


def dist(x1, x2=None, metric='sqeuclidean'):
    """Compute distance between samples in x1 and x2 using function
    scipy.spatial.distance.cdist

    Parameters
    ----------

    x1 : ndarray, shape (n1,d)
        matrix with n1 samples of size d
    x2 : array, shape (n2,d), optional
        matrix with n2 samples of size d (if None then x2=x1)
    metric : str | callable, optional
        Name of the metric to be computed (full list in the doc of scipy),
        If a string, the distance function can be 'braycurtis', 'canberra',
        'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice',
        'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis',
        'matching', 'minkowski', 'rogerstanimoto', 'russellrao',
        'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean',
        'wminkowski', 'yule'.


    Returns
    -------

    M : np.array (n1,n2)
        distance matrix computed with given metric

    """
    if x2 is None:
        x2 = x1
    if metric == "sqeuclidean":
        # Dedicated implementation instead of cdist for the default metric.
        return euclidean_distances(x1, x2, squared=True)
    return cdist(x1, x2, metric=metric)


def dist0(n, method='lin_square'):
    """Compute standard cost matrices of size (n, n) for OT problems

    Parameters
    ----------
    n : int
        Size of the cost matrix.
    method : str, optional
        Type of loss matrix chosen from:

        * 'lin_square' : linear sampling between 0 and n-1, quadratic loss

    Returns
    -------
    M : ndarray, shape (n1,n2)
        Distance matrix computed with given metric. For any other value of
        ``method`` the scalar ``0`` is returned.
    """
    res = 0
    if method == 'lin_square':
        x = np.arange(n, dtype=np.float64).reshape((n, 1))
        res = dist(x, x)
    return res


def cost_normalization(C, norm=None):
    """ Apply normalization to the loss matrix

    Parameters
    ----------
    C : ndarray, shape (n1, n2)
        The cost matrix to normalize.
    norm : str
        Type of normalization from 'median', 'max', 'log', 'loglog'. ``None``
        (the default) does not normalize.

    Returns
    -------
    C : ndarray, shape (n1, n2)
        The input cost matrix normalized according to given norm.

    Raises
    ------
    ValueError
        If ``norm`` is neither ``None`` nor one of the valid options.

    Notes
    -----
    'median' and 'max' divide ``C`` in place (the caller's array is
    modified), whereas 'log' and 'loglog' return a new array.
    """
    if norm is None:
        pass
    elif norm == "median":
        C /= float(np.median(C))
    elif norm == "max":
        C /= float(np.max(C))
    elif norm == "log":
        C = np.log(1 + C)
    elif norm == "loglog":
        C = np.log1p(np.log1p(C))
    else:
        raise ValueError('Norm %s is not a valid option.\n'
                         'Valid options are:\n'
                         'median, max, log, loglog' % norm)
    return C


def dots(*args):
    """ dots function for multiple matrix multiply """
    return reduce(np.dot, args)


def fun(f, q_in, q_out):
    """ Utility function for parmap with no serializing problems

    Worker loop: reads ``(index, item)`` pairs from ``q_in`` and writes
    ``(index, f(item))`` to ``q_out`` until a pair whose index is ``None``
    is received.
    """
    while True:
        i, x = q_in.get()
        if i is None:
            break
        q_out.put((i, f(x)))


def parmap(f, X, nprocs=multiprocessing.cpu_count()):
    """ parallel map for multiprocessing (only map on windows)

    Parameters
    ----------
    f : callable
        Function applied to every element of ``X``.
    X : iterable
        Items to process.
    nprocs : int, optional
        Number of worker processes (defaults to the CPU count evaluated at
        import time).

    Returns
    -------
    results : list
        ``f`` applied to each item of ``X``, in input order.
    """
    if sys.platform.endswith('win32'):
        # No worker processes on windows: plain sequential map.
        return list(map(f, X))

    q_in = multiprocessing.Queue(1)
    q_out = multiprocessing.Queue()

    proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out))
            for _ in range(nprocs)]
    for p in proc:
        p.daemon = True
        p.start()

    # Feed the workers; items are tagged with their index so that the
    # original order can be restored afterwards.
    n_sent = 0
    for i, x in enumerate(X):
        q_in.put((i, x))
        n_sent += 1
    # One termination marker per worker.
    for _ in range(nprocs):
        q_in.put((None, None))
    res = [q_out.get() for _ in range(n_sent)]

    for p in proc:
        p.join()

    return [x for i, x in sorted(res)]


def check_params(**kwargs):
    """check_params: check whether some parameters are missing

    Parameters
    ----------
    **kwargs
        Named parameters to check; a parameter is considered missing when
        its value is ``None``.

    Returns
    -------
    check : bool
        ``True`` if no parameter is missing, ``False`` otherwise (the names
        of the missing parameters are printed).
    """
    missing_params = [name for name, value in kwargs.items()
                      if value is None]

    if missing_params:
        print("POT - Warning: following necessary parameters are missing")
        for p in missing_params:
            print("\n", p)
        return False

    return True


def check_random_state(seed):
    """Turn seed into a np.random.RandomState instance

    Parameters
    ----------
    seed : None | int | instance of RandomState
        If seed is None, return the RandomState singleton used by np.random.
        If seed is an int, return a new RandomState instance seeded with seed.
        If seed is already a RandomState instance, return it.
        Otherwise raise ValueError.

    Returns
    -------
    random_state : np.random.RandomState

    Raises
    ------
    ValueError
        If ``seed`` is none of the supported types.
    """
    if seed is None or seed is np.random:
        return np.random.mtrand._rand
    if isinstance(seed, (int, np.integer)):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    raise ValueError('{} cannot be used to seed a numpy.random.RandomState'
                     ' instance'.format(seed))


class deprecated(object):
    """Decorator to mark a function or class as deprecated.

    deprecated class from scikit-learn package
    https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py
    Issue a warning when the function is called/the class is instantiated and
    adds a warning to the docstring.
    The optional extra argument will be appended to the deprecation message
    and the docstring. Note: to use this with the default value for extra, put
    in an empty of parentheses:

    >>> from ot.deprecation import deprecated  # doctest: +SKIP
    >>> @deprecated()  # doctest: +SKIP
    ... def some_function(): pass  # doctest: +SKIP

    Parameters
    ----------
    extra : str
        To be added to the deprecation messages.
    """

    # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary,
    # but with many changes.

    def __init__(self, extra=''):
        self.extra = extra

    def __call__(self, obj):
        """Call method

        Parameters
        ----------
        obj : object
            Class or function to decorate.
        """
        if isinstance(obj, type):
            return self._decorate_class(obj)
        else:
            return self._decorate_fun(obj)

    def _decorate_class(self, cls):
        """Replace ``cls.__init__`` by a wrapper emitting the warning"""
        msg = "Class %s is deprecated" % cls.__name__
        if self.extra:
            msg += "; %s" % self.extra

        # FIXME: we should probably reset __new__ for full generality
        init = cls.__init__

        def wrapped(*args, **kwargs):
            warnings.warn(msg, category=DeprecationWarning)
            return init(*args, **kwargs)

        cls.__init__ = wrapped

        wrapped.__name__ = '__init__'
        wrapped.__doc__ = self._update_doc(init.__doc__)
        # Kept so that the original constructor signature can still be
        # introspected (see BaseEstimator._get_param_names).
        wrapped.deprecated_original = init

        return cls

    def _decorate_fun(self, fun):
        """Decorate function fun"""

        msg = "Function %s is deprecated" % fun.__name__
        if self.extra:
            msg += "; %s" % self.extra

        def wrapped(*args, **kwargs):
            warnings.warn(msg, category=DeprecationWarning)
            return fun(*args, **kwargs)

        wrapped.__name__ = fun.__name__
        wrapped.__dict__ = fun.__dict__
        wrapped.__doc__ = self._update_doc(fun.__doc__)

        return wrapped

    def _update_doc(self, olddoc):
        """Prefix ``olddoc`` with a DEPRECATED notice"""
        newdoc = "DEPRECATED"
        if self.extra:
            newdoc = "%s: %s" % (newdoc, self.extra)
        if olddoc:
            newdoc = "%s\n\n%s" % (newdoc, olddoc)
        return newdoc


def _is_deprecated(func):
    """Helper to check if func is wrapped by our deprecated decorator"""
    if sys.version_info < (3, 5):
        raise NotImplementedError("This is only available for python3.5 "
                                  "or above")
    closures = getattr(func, '__closure__', [])
    if closures is None:
        closures = []
    # The wrapper built by ``deprecated`` closes over its warning message,
    # a string containing the word 'deprecated'.
    is_deprecated = ('deprecated' in ''.join([c.cell_contents
                                              for c in closures
                                              if isinstance(c.cell_contents,
                                                            str)]))
    return is_deprecated


class BaseEstimator(object):
    """Base class for most objects in POT

    Code adapted from sklearn BaseEstimator class

    Notes
    -----
    All estimators should specify all the parameters that can be set
    at the class level in their ``__init__`` as explicit keyword
    arguments (no ``*args`` or ``**kwargs``).
    """

    @classmethod
    def _get_param_names(cls):
        """Get parameter names for the estimator"""

        # fetch the constructor or the original constructor before
        # deprecation wrapping if any
        init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
        if init is object.__init__:
            # No explicit constructor to introspect
            return []

        # introspect the constructor arguments to find the model parameters
        # to represent
        init_signature = signature(init)
        # Consider the constructor parameters excluding 'self'
        parameters = [p for p in init_signature.parameters.values()
                      if p.name != 'self' and p.kind != p.VAR_KEYWORD]
        for p in parameters:
            if p.kind == p.VAR_POSITIONAL:
                raise RuntimeError("POT estimators should always "
                                   "specify their parameters in the signature"
                                   " of their __init__ (no varargs)."
                                   " %s with constructor %s doesn't "
                                   " follow this convention."
                                   % (cls, init_signature))
        # Extract and sort argument names excluding 'self'
        return sorted([p.name for p in parameters])

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : bool, optional
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values.
        """
        out = dict()
        for key in self._get_param_names():
            # We need deprecation warnings to always be on in order to
            # catch deprecated param values. The filter is installed inside
            # catch_warnings so that the global filter list is restored
            # exactly on exit, instead of being popped by hand (which could
            # drop a filter previously installed by the user).
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always", DeprecationWarning)
                value = getattr(self, key, None)
            if len(w) and w[0].category == DeprecationWarning:
                # if the parameter is deprecated, don't show it
                continue

            # XXX: should we rather test if instance of estimator?
            if deep and hasattr(value, 'get_params'):
                deep_items = value.get_params().items()
                out.update((key + '__' + k, val) for k, val in deep_items)
            out[key] = value
        return out

    def set_params(self, **params):
        """Set the parameters of this estimator.

        The method works on simple estimators as well as on nested objects
        (such as pipelines). The latter have parameters of the form
        ``component__parameter`` so that it's possible to update each
        component of a nested object.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If a parameter name is not a valid parameter of the estimator.
        """
        if not params:
            # Simple optimisation to gain speed (inspect is slow)
            return self
        valid_params = self.get_params(deep=True)
        for key, value in params.items():
            split = key.split('__', 1)
            if len(split) > 1:
                # nested objects case
                name, sub_name = split
                if name not in valid_params:
                    raise ValueError('Invalid parameter %s for estimator %s. '
                                     'Check the list of available parameters '
                                     'with `estimator.get_params().keys()`.' %
                                     (name, self))
                sub_object = valid_params[name]
                sub_object.set_params(**{sub_name: value})
            else:
                # simple objects case
                if key not in valid_params:
                    raise ValueError('Invalid parameter %s for estimator %s. '
                                     'Check the list of available parameters '
                                     'with `estimator.get_params().keys()`.' %
                                     (key, self.__class__.__name__))
                setattr(self, key, value)
        return self


class UndefinedParameter(Exception):
    """
    Aim at raising an Exception when a undefined parameter is called

    """
    pass