diff options
author | Antoine Rolet <antoine.rolet@gmail.com> | 2017-09-05 15:30:50 +0900 |
---|---|---|
committer | Antoine Rolet <antoine.rolet@gmail.com> | 2017-09-05 15:30:50 +0900 |
commit | 13dfb3ddbbd8926b4751b82dd41c5570253b1f07 (patch) | |
tree | b28098e98640c64483a599103e2fdb5df46d2c79 /ot/utils.py | |
parent | 185eb3e2ef34b5ce6b8f90a28a5bcc78432b7fd3 (diff) | |
parent | 16697047eff9326a0ecb483317c13a854a3d3a71 (diff) |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 346 |
1 files changed, 317 insertions, 29 deletions
diff --git a/ot/utils.py b/ot/utils.py index 7ad7637..31a002b 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -2,36 +2,50 @@ """ Various function that can be usefull """ + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +import multiprocessing +from functools import reduce +import time + import numpy as np from scipy.spatial.distance import cdist -import multiprocessing +import sys +import warnings + + +__time_tic_toc = time.time() -import time -__time_tic_toc=time.time() def tic(): """ Python implementation of Matlab tic() function """ global __time_tic_toc - __time_tic_toc=time.time() + __time_tic_toc = time.time() + def toc(message='Elapsed time : {} s'): """ Python implementation of Matlab toc() function """ - t=time.time() - print(message.format(t-__time_tic_toc)) - return t-__time_tic_toc + t = time.time() + print(message.format(t - __time_tic_toc)) + return t - __time_tic_toc + def toq(): """ Python implementation of Julia toc() function """ - t=time.time() - return t-__time_tic_toc + t = time.time() + return t - __time_tic_toc -def kernel(x1,x2,method='gaussian',sigma=1,**kwargs): +def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): """Compute kernel matrix""" - if method.lower() in ['gaussian','gauss','rbf']: - K=np.exp(-dist(x1,x2)/(2*sigma**2)) + if method.lower() in ['gaussian', 'gauss', 'rbf']: + K = np.exp(-dist(x1, x2) / (2 * sigma**2)) return K + def unif(n): """ return a uniform histogram of length n (simplex) @@ -48,17 +62,19 @@ def unif(n): """ - return np.ones((n,))/n + return np.ones((n,)) / n + -def clean_zeros(a,b,M): - """ Remove all components with zeros weights in a and b +def clean_zeros(a, b, M): + """ Remove all components with zeros weights in a and b """ - M2=M[a>0,:][:,b>0].copy() # copy force c style matrix (froemd) - a2=a[a>0] - b2=b[b>0] - return a2,b2,M2 + M2 = M[a > 0, :][:, b > 0].copy() # copy force c style matrix (froemd) + a2 = a[a > 0] + b2 = b[b > 0] + return a2, b2, M2 -def dist(x1,x2=None,metric='sqeuclidean'): + +def dist(x1, x2=None, metric='sqeuclidean'): """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist Parameters @@ -84,12 +100,12 @@ def dist(x1,x2=None,metric='sqeuclidean'): """ if x2 is None: - x2=x1 + x2 = x1 - return cdist(x1,x2,metric=metric) + return cdist(x1, x2, metric=metric) -def dist0(n,method='lin_square'): +def dist0(n, method='lin_square'): """Compute standard cost matrices of size (n,n) for OT problems Parameters @@ -111,16 +127,50 @@ def dist0(n,method='lin_square'): """ - res=0 - if method=='lin_square': - x=np.arange(n,dtype=np.float64).reshape((n,1)) - res=dist(x,x) + res = 0 + if method == 'lin_square': + x = np.arange(n, dtype=np.float64).reshape((n, 1)) + res = dist(x, x) return res +def cost_normalization(C, norm=None): + """ Apply normalization to the loss matrix + + + Parameters + ---------- + C : np.array (n1, n2) + The cost matrix to normalize. + norm : str + type of normalization from 'median','max','log','loglog'. Any other + value do not normalize. + + + Returns + ------- + + C : np.array (n1, n2) + The input cost matrix normalized according to given norm. + + """ + + if norm == "median": + C /= float(np.median(C)) + elif norm == "max": + C /= float(np.max(C)) + elif norm == "log": + C = np.log(1 + C) + elif norm == "loglog": + C = np.log(1 + np.log(1 + C)) + + return C + + def dots(*args): """ dots function for multiple matrix multiply """ - return reduce(np.dot,args) + return reduce(np.dot, args) + def fun(f, q_in, q_out): """ Utility function for parmap with no serializing problems """ @@ -130,6 +180,7 @@ def fun(f, q_in, q_out): break q_out.put((i, f(x))) + def parmap(f, X, nprocs=multiprocessing.cpu_count()): """ paralell map for multiprocessing """ q_in = multiprocessing.Queue(1) @@ -147,4 +198,241 @@ def parmap(f, X, nprocs=multiprocessing.cpu_count()): [p.join() for p in proc] - return [x for i, x in sorted(res)]
\ No newline at end of file + return [x for i, x in sorted(res)] + + +def check_params(**kwargs): + """check_params: check whether some parameters are missing + """ + + missing_params = [] + check = True + + for param in kwargs: + if kwargs[param] is None: + missing_params.append(param) + + if len(missing_params) > 0: + print("POT - Warning: following necessary parameters are missing") + for p in missing_params: + print("\n", p) + + check = False + + return check + + +class deprecated(object): + """Decorator to mark a function or class as deprecated. + + deprecated class from scikit-learn package + https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py + Issue a warning when the function is called/the class is instantiated and + adds a warning to the docstring. + The optional extra argument will be appended to the deprecation message + and the docstring. Note: to use this with the default value for extra, put + in an empty of parentheses: + >>> from ot.deprecation import deprecated + >>> @deprecated() + ... def some_function(): pass + + Parameters + ---------- + extra : string + to be added to the deprecation messages + """ + + # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary, + # but with many changes. + + def __init__(self, extra=''): + self.extra = extra + + def __call__(self, obj): + """Call method + Parameters + ---------- + obj : object + """ + if isinstance(obj, type): + return self._decorate_class(obj) + else: + return self._decorate_fun(obj) + + def _decorate_class(self, cls): + msg = "Class %s is deprecated" % cls.__name__ + if self.extra: + msg += "; %s" % self.extra + + # FIXME: we should probably reset __new__ for full generality + init = cls.__init__ + + def wrapped(*args, **kwargs): + warnings.warn(msg, category=DeprecationWarning) + return init(*args, **kwargs) + + cls.__init__ = wrapped + + wrapped.__name__ = '__init__' + wrapped.__doc__ = self._update_doc(init.__doc__) + wrapped.deprecated_original = init + + return cls + + def _decorate_fun(self, fun): + """Decorate function fun""" + + msg = "Function %s is deprecated" % fun.__name__ + if self.extra: + msg += "; %s" % self.extra + + def wrapped(*args, **kwargs): + warnings.warn(msg, category=DeprecationWarning) + return fun(*args, **kwargs) + + wrapped.__name__ = fun.__name__ + wrapped.__dict__ = fun.__dict__ + wrapped.__doc__ = self._update_doc(fun.__doc__) + + return wrapped + + def _update_doc(self, olddoc): + newdoc = "DEPRECATED" + if self.extra: + newdoc = "%s: %s" % (newdoc, self.extra) + if olddoc: + newdoc = "%s\n\n%s" % (newdoc, olddoc) + return newdoc + + +def _is_deprecated(func): + """Helper to check if func is wraped by our deprecated decorator""" + if sys.version_info < (3, 5): + raise NotImplementedError("This is only available for python3.5 " + "or above") + closures = getattr(func, '__closure__', []) + if closures is None: + closures = [] + is_deprecated = ('deprecated' in ''.join([c.cell_contents + for c in closures + if isinstance(c.cell_contents, str)])) + return is_deprecated + + +class BaseEstimator(object): + """Base class for most objects in POT + adapted from sklearn BaseEstimator class + + Notes + ----- + All estimators should specify all the parameters that can be set + at the class level in their ``__init__`` as explicit keyword + arguments (no ``*args`` or ``**kwargs``). + """ + + @classmethod + def _get_param_names(cls): + """Get parameter names for the estimator""" + try: + from inspect import signature + except ImportError: + from .externals.funcsigs import signature + # fetch the constructor or the original constructor before + # deprecation wrapping if any + init = getattr(cls.__init__, 'deprecated_original', cls.__init__) + if init is object.__init__: + # No explicit constructor to introspect + return [] + + # introspect the constructor arguments to find the model parameters + # to represent + init_signature = signature(init) + # Consider the constructor parameters excluding 'self' + parameters = [p for p in init_signature.parameters.values() + if p.name != 'self' and p.kind != p.VAR_KEYWORD] + for p in parameters: + if p.kind == p.VAR_POSITIONAL: + raise RuntimeError("POT estimators should always " + "specify their parameters in the signature" + " of their __init__ (no varargs)." + " %s with constructor %s doesn't " + " follow this convention." + % (cls, init_signature)) + # Extract and sort argument names excluding 'self' + return sorted([p.name for p in parameters]) + + def get_params(self, deep=True): + """Get parameters for this estimator. + + Parameters + ---------- + deep : boolean, optional + If True, will return the parameters for this estimator and + contained subobjects that are estimators. + + Returns + ------- + params : mapping of string to any + Parameter names mapped to their values. + """ + out = dict() + for key in self._get_param_names(): + # We need deprecation warnings to always be on in order to + # catch deprecated param values. + # This is set in utils/__init__.py but it gets overwritten + # when running under python3 somehow. + warnings.simplefilter("always", DeprecationWarning) + try: + with warnings.catch_warnings(record=True) as w: + value = getattr(self, key, None) + if len(w) and w[0].category == DeprecationWarning: + # if the parameter is deprecated, don't show it + continue + finally: + warnings.filters.pop(0) + + # XXX: should we rather test if instance of estimator? + if deep and hasattr(value, 'get_params'): + deep_items = value.get_params().items() + out.update((key + '__' + k, val) for k, val in deep_items) + out[key] = value + return out + + def set_params(self, **params): + """Set the parameters of this estimator. + + The method works on simple estimators as well as on nested objects + (such as pipelines). The latter have parameters of the form + ``<component>__<parameter>`` so that it's possible to update each + component of a nested object. + + Returns + ------- + self + """ + if not params: + # Simple optimisation to gain speed (inspect is slow) + return self + valid_params = self.get_params(deep=True) + # for key, value in iteritems(params): + for key, value in params.items(): + split = key.split('__', 1) + if len(split) > 1: + # nested objects case + name, sub_name = split + if name not in valid_params: + raise ValueError('Invalid parameter %s for estimator %s. ' + 'Check the list of available parameters ' + 'with `estimator.get_params().keys()`.' % + (name, self)) + sub_object = valid_params[name] + sub_object.set_params(**{sub_name: value}) + else: + # simple objects case + if key not in valid_params: + raise ValueError('Invalid parameter %s for estimator %s. ' + 'Check the list of available parameters ' + 'with `estimator.get_params().keys()`.' % + (key, self.__class__.__name__)) + setattr(self, key, value) + return self |