diff options
author | Slasnista <stan.chambon@gmail.com> | 2017-08-28 10:43:31 +0200 |
---|---|---|
committer | Slasnista <stan.chambon@gmail.com> | 2017-08-28 10:43:31 +0200 |
commit | 55840f6bccadd79caf722d86f06da857e3045453 (patch) | |
tree | 833e706b9109579b68a94956a05d20e6fa773e49 /ot/utils.py | |
parent | 892d7ce10912de3d07692b5083578932a32ab61f (diff) |
move no da objects into utils.py
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 219 |
1 files changed, 219 insertions, 0 deletions
diff --git a/ot/utils.py b/ot/utils.py index 2b2f8b3..29ad536 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -13,6 +13,9 @@ import time import numpy as np from scipy.spatial.distance import cdist +import sys +import warnings + __time_tic_toc = time.time() @@ -163,3 +166,219 @@ def parmap(f, X, nprocs=multiprocessing.cpu_count()): [p.join() for p in proc] return [x for i, x in sorted(res)] + + +class deprecated(object): + """Decorator to mark a function or class as deprecated. + + deprecated class from scikit-learn package + https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py + Issue a warning when the function is called/the class is instantiated and + adds a warning to the docstring. + The optional extra argument will be appended to the deprecation message + and the docstring. Note: to use this with the default value for extra, put + in an empty of parentheses: + >>> from ot.deprecation import deprecated + >>> @deprecated() + ... def some_function(): pass + + Parameters + ---------- + extra : string + to be added to the deprecation messages + """ + + # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary, + # but with many changes. + + def __init__(self, extra=''): + self.extra = extra + + def __call__(self, obj): + """Call method + Parameters + ---------- + obj : object + """ + if isinstance(obj, type): + return self._decorate_class(obj) + else: + return self._decorate_fun(obj) + + def _decorate_class(self, cls): + msg = "Class %s is deprecated" % cls.__name__ + if self.extra: + msg += "; %s" % self.extra + + # FIXME: we should probably reset __new__ for full generality + init = cls.__init__ + + def wrapped(*args, **kwargs): + warnings.warn(msg, category=DeprecationWarning) + return init(*args, **kwargs) + + cls.__init__ = wrapped + + wrapped.__name__ = '__init__' + wrapped.__doc__ = self._update_doc(init.__doc__) + wrapped.deprecated_original = init + + return cls + + def _decorate_fun(self, fun): + """Decorate function fun""" + + msg = "Function %s is deprecated" % fun.__name__ + if self.extra: + msg += "; %s" % self.extra + + def wrapped(*args, **kwargs): + warnings.warn(msg, category=DeprecationWarning) + return fun(*args, **kwargs) + + wrapped.__name__ = fun.__name__ + wrapped.__dict__ = fun.__dict__ + wrapped.__doc__ = self._update_doc(fun.__doc__) + + return wrapped + + def _update_doc(self, olddoc): + newdoc = "DEPRECATED" + if self.extra: + newdoc = "%s: %s" % (newdoc, self.extra) + if olddoc: + newdoc = "%s\n\n%s" % (newdoc, olddoc) + return newdoc + + +def _is_deprecated(func): + """Helper to check if func is wraped by our deprecated decorator""" + if sys.version_info < (3, 5): + raise NotImplementedError("This is only available for python3.5 " + "or above") + closures = getattr(func, '__closure__', []) + if closures is None: + closures = [] + is_deprecated = ('deprecated' in ''.join([c.cell_contents + for c in closures + if isinstance(c.cell_contents, str)])) + return is_deprecated + + +class BaseEstimator(object): + """Base class for most objects in POT + adapted from sklearn BaseEstimator class + + Notes + ----- + All estimators should specify all the parameters that can be set + at the class level in their ``__init__`` as explicit keyword + arguments (no ``*args`` or ``**kwargs``). + """ + + @classmethod + def _get_param_names(cls): + """Get parameter names for the estimator""" + try: + from inspect import signature + except ImportError: + from .externals.funcsigs import signature + # fetch the constructor or the original constructor before + # deprecation wrapping if any + init = getattr(cls.__init__, 'deprecated_original', cls.__init__) + if init is object.__init__: + # No explicit constructor to introspect + return [] + + # introspect the constructor arguments to find the model parameters + # to represent + init_signature = signature(init) + # Consider the constructor parameters excluding 'self' + parameters = [p for p in init_signature.parameters.values() + if p.name != 'self' and p.kind != p.VAR_KEYWORD] + for p in parameters: + if p.kind == p.VAR_POSITIONAL: + raise RuntimeError("POT estimators should always " + "specify their parameters in the signature" + " of their __init__ (no varargs)." + " %s with constructor %s doesn't " + " follow this convention." + % (cls, init_signature)) + # Extract and sort argument names excluding 'self' + return sorted([p.name for p in parameters]) + + def get_params(self, deep=True): + """Get parameters for this estimator. + + Parameters + ---------- + deep : boolean, optional + If True, will return the parameters for this estimator and + contained subobjects that are estimators. + + Returns + ------- + params : mapping of string to any + Parameter names mapped to their values. + """ + out = dict() + for key in self._get_param_names(): + # We need deprecation warnings to always be on in order to + # catch deprecated param values. + # This is set in utils/__init__.py but it gets overwritten + # when running under python3 somehow. + warnings.simplefilter("always", DeprecationWarning) + try: + with warnings.catch_warnings(record=True) as w: + value = getattr(self, key, None) + if len(w) and w[0].category == DeprecationWarning: + # if the parameter is deprecated, don't show it + continue + finally: + warnings.filters.pop(0) + + # XXX: should we rather test if instance of estimator? + if deep and hasattr(value, 'get_params'): + deep_items = value.get_params().items() + out.update((key + '__' + k, val) for k, val in deep_items) + out[key] = value + return out + + def set_params(self, **params): + """Set the parameters of this estimator. + + The method works on simple estimators as well as on nested objects + (such as pipelines). The latter have parameters of the form + ``<component>__<parameter>`` so that it's possible to update each + component of a nested object. + + Returns + ------- + self + """ + if not params: + # Simple optimisation to gain speed (inspect is slow) + return self + valid_params = self.get_params(deep=True) + # for key, value in iteritems(params): + for key, value in params.items(): + split = key.split('__', 1) + if len(split) > 1: + # nested objects case + name, sub_name = split + if name not in valid_params: + raise ValueError('Invalid parameter %s for estimator %s. ' + 'Check the list of available parameters ' + 'with `estimator.get_params().keys()`.' % + (name, self)) + sub_object = valid_params[name] + sub_object.set_params(**{sub_name: value}) + else: + # simple objects case + if key not in valid_params: + raise ValueError('Invalid parameter %s for estimator %s. ' + 'Check the list of available parameters ' + 'with `estimator.get_params().keys()`.' % + (key, self.__class__.__name__)) + setattr(self, key, value) + return self |