From 55840f6bccadd79caf722d86f06da857e3045453 Mon Sep 17 00:00:00 2001 From: Slasnista Date: Mon, 28 Aug 2017 10:43:31 +0200 Subject: move no da objects into utils.py --- README.md | 2 + ot/da.py | 136 +-------------------------------- ot/deprecation.py | 103 ------------------------- ot/utils.py | 219 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 222 insertions(+), 238 deletions(-) delete mode 100644 ot/deprecation.py diff --git a/README.md b/README.md index 7a65106..27b4643 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,8 @@ The contributors to this library are: * [Laetitia Chapel](http://people.irisa.fr/Laetitia.Chapel/) * [Michael Perrot](http://perso.univ-st-etienne.fr/pem82055/) (Mapping estimation) * [Léo Gautheron](https://github.com/aje) (GPU implementation) +* [Nathalie Gayraud]() +* [Stanislas Chambon](https://slasnista.github.io/) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): diff --git a/ot/da.py b/ot/da.py index 5a34979..8c62669 100644 --- a/ot/da.py +++ b/ot/da.py @@ -10,14 +10,13 @@ Domain adaptation with optimal transport # License: MIT License import numpy as np -import warnings from .bregman import sinkhorn from .lp import emd from .utils import unif, dist, kernel +from .utils import deprecated, BaseEstimator from .optim import cg from .optim import gcg -from .deprecation import deprecated def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, @@ -936,139 +935,6 @@ class OTDA_mapping_kernel(OTDA_mapping_linear): print("Warning, model not fitted yet, returning None") return None -############################################################################## -# proposal -############################################################################## - - -# adapted from sklearn - -class BaseEstimator(object): - """Base class for all estimators in scikit-learn - Notes - ----- - All estimators should 
specify all the parameters that can be set - at the class level in their ``__init__`` as explicit keyword - arguments (no ``*args`` or ``**kwargs``). - """ - - @classmethod - def _get_param_names(cls): - """Get parameter names for the estimator""" - try: - from inspect import signature - except ImportError: - from .externals.funcsigs import signature - # fetch the constructor or the original constructor before - # deprecation wrapping if any - init = getattr(cls.__init__, 'deprecated_original', cls.__init__) - if init is object.__init__: - # No explicit constructor to introspect - return [] - - # introspect the constructor arguments to find the model parameters - # to represent - init_signature = signature(init) - # Consider the constructor parameters excluding 'self' - parameters = [p for p in init_signature.parameters.values() - if p.name != 'self' and p.kind != p.VAR_KEYWORD] - for p in parameters: - if p.kind == p.VAR_POSITIONAL: - raise RuntimeError("scikit-learn estimators should always " - "specify their parameters in the signature" - " of their __init__ (no varargs)." - " %s with constructor %s doesn't " - " follow this convention." - % (cls, init_signature)) - # Extract and sort argument names excluding 'self' - return sorted([p.name for p in parameters]) - - def get_params(self, deep=True): - """Get parameters for this estimator. - - Parameters - ---------- - deep : boolean, optional - If True, will return the parameters for this estimator and - contained subobjects that are estimators. - - Returns - ------- - params : mapping of string to any - Parameter names mapped to their values. - """ - out = dict() - for key in self._get_param_names(): - # We need deprecation warnings to always be on in order to - # catch deprecated param values. - # This is set in utils/__init__.py but it gets overwritten - # when running under python3 somehow. 
- warnings.simplefilter("always", DeprecationWarning) - try: - with warnings.catch_warnings(record=True) as w: - value = getattr(self, key, None) - if len(w) and w[0].category == DeprecationWarning: - # if the parameter is deprecated, don't show it - continue - finally: - warnings.filters.pop(0) - - # XXX: should we rather test if instance of estimator? - if deep and hasattr(value, 'get_params'): - deep_items = value.get_params().items() - out.update((key + '__' + k, val) for k, val in deep_items) - out[key] = value - return out - - def set_params(self, **params): - """Set the parameters of this estimator. - - The method works on simple estimators as well as on nested objects - (such as pipelines). The latter have parameters of the form - ``__`` so that it's possible to update each - component of a nested object. - - Returns - ------- - self - """ - if not params: - # Simple optimisation to gain speed (inspect is slow) - return self - valid_params = self.get_params(deep=True) - # for key, value in iteritems(params): - for key, value in params.items(): - split = key.split('__', 1) - if len(split) > 1: - # nested objects case - name, sub_name = split - if name not in valid_params: - raise ValueError('Invalid parameter %s for estimator %s. ' - 'Check the list of available parameters ' - 'with `estimator.get_params().keys()`.' % - (name, self)) - sub_object = valid_params[name] - sub_object.set_params(**{sub_name: value}) - else: - # simple objects case - if key not in valid_params: - raise ValueError('Invalid parameter %s for estimator %s. ' - 'Check the list of available parameters ' - 'with `estimator.get_params().keys()`.' 
% - (key, self.__class__.__name__)) - setattr(self, key, value) - return self - - def __repr__(self): - from sklearn.base import _pprint - class_name = self.__class__.__name__ - return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False), - offset=len(class_name),),) - - # __getstate__ and __setstate__ are omitted because they only contain - # conditionals that are not satisfied by our objects (e.g., - # ``if type(self).__module__.startswith('sklearn.')``. - def distribution_estimation_uniform(X): """estimates a uniform distribution from an array of samples X diff --git a/ot/deprecation.py b/ot/deprecation.py deleted file mode 100644 index 2b16427..0000000 --- a/ot/deprecation.py +++ /dev/null @@ -1,103 +0,0 @@ -""" - deprecated class from scikit-learn package - https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py -""" - -import sys -import warnings - -__all__ = ["deprecated", ] - - -class deprecated(object): - """Decorator to mark a function or class as deprecated. - Issue a warning when the function is called/the class is instantiated and - adds a warning to the docstring. - The optional extra argument will be appended to the deprecation message - and the docstring. Note: to use this with the default value for extra, put - in an empty of parentheses: - >>> from ot.deprecation import deprecated - >>> @deprecated() - ... def some_function(): pass - - Parameters - ---------- - extra : string - to be added to the deprecation messages - """ - - # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary, - # but with many changes. 
- - def __init__(self, extra=''): - self.extra = extra - - def __call__(self, obj): - """Call method - Parameters - ---------- - obj : object - """ - if isinstance(obj, type): - return self._decorate_class(obj) - else: - return self._decorate_fun(obj) - - def _decorate_class(self, cls): - msg = "Class %s is deprecated" % cls.__name__ - if self.extra: - msg += "; %s" % self.extra - - # FIXME: we should probably reset __new__ for full generality - init = cls.__init__ - - def wrapped(*args, **kwargs): - warnings.warn(msg, category=DeprecationWarning) - return init(*args, **kwargs) - - cls.__init__ = wrapped - - wrapped.__name__ = '__init__' - wrapped.__doc__ = self._update_doc(init.__doc__) - wrapped.deprecated_original = init - - return cls - - def _decorate_fun(self, fun): - """Decorate function fun""" - - msg = "Function %s is deprecated" % fun.__name__ - if self.extra: - msg += "; %s" % self.extra - - def wrapped(*args, **kwargs): - warnings.warn(msg, category=DeprecationWarning) - return fun(*args, **kwargs) - - wrapped.__name__ = fun.__name__ - wrapped.__dict__ = fun.__dict__ - wrapped.__doc__ = self._update_doc(fun.__doc__) - - return wrapped - - def _update_doc(self, olddoc): - newdoc = "DEPRECATED" - if self.extra: - newdoc = "%s: %s" % (newdoc, self.extra) - if olddoc: - newdoc = "%s\n\n%s" % (newdoc, olddoc) - return newdoc - - -def _is_deprecated(func): - """Helper to check if func is wraped by our deprecated decorator""" - if sys.version_info < (3, 5): - raise NotImplementedError("This is only available for python3.5 " - "or above") - closures = getattr(func, '__closure__', []) - if closures is None: - closures = [] - is_deprecated = ('deprecated' in ''.join([c.cell_contents - for c in closures - if isinstance(c.cell_contents, str)])) - return is_deprecated diff --git a/ot/utils.py b/ot/utils.py index 2b2f8b3..29ad536 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -13,6 +13,9 @@ import time import numpy as np from scipy.spatial.distance import cdist 
+import sys +import warnings + __time_tic_toc = time.time() @@ -163,3 +166,219 @@ def parmap(f, X, nprocs=multiprocessing.cpu_count()): [p.join() for p in proc] return [x for i, x in sorted(res)] + + +class deprecated(object): + """Decorator to mark a function or class as deprecated. + + deprecated class from scikit-learn package + https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py + Issue a warning when the function is called/the class is instantiated and + adds a warning to the docstring. + The optional extra argument will be appended to the deprecation message + and the docstring. Note: to use this with the default value for extra, put + in an empty pair of parentheses: + >>> from ot.utils import deprecated + >>> @deprecated() + ... def some_function(): pass + + Parameters + ---------- + extra : string + to be added to the deprecation messages + """ + + # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary, + # but with many changes. + + def __init__(self, extra=''): + self.extra = extra + + def __call__(self, obj): + """Call method + Parameters + ---------- + obj : object + """ + if isinstance(obj, type): + return self._decorate_class(obj) + else: + return self._decorate_fun(obj) + + def _decorate_class(self, cls): + msg = "Class %s is deprecated" % cls.__name__ + if self.extra: + msg += "; %s" % self.extra + + # FIXME: we should probably reset __new__ for full generality + init = cls.__init__ + + def wrapped(*args, **kwargs): + warnings.warn(msg, category=DeprecationWarning) + return init(*args, **kwargs) + + cls.__init__ = wrapped + + wrapped.__name__ = '__init__' + wrapped.__doc__ = self._update_doc(init.__doc__) + wrapped.deprecated_original = init + + return cls + + def _decorate_fun(self, fun): + """Decorate function fun""" + + msg = "Function %s is deprecated" % fun.__name__ + if self.extra: + msg += "; %s" % self.extra + + def wrapped(*args, **kwargs): + warnings.warn(msg, category=DeprecationWarning) +
return fun(*args, **kwargs) + + wrapped.__name__ = fun.__name__ + wrapped.__dict__ = fun.__dict__ + wrapped.__doc__ = self._update_doc(fun.__doc__) + + return wrapped + + def _update_doc(self, olddoc): + newdoc = "DEPRECATED" + if self.extra: + newdoc = "%s: %s" % (newdoc, self.extra) + if olddoc: + newdoc = "%s\n\n%s" % (newdoc, olddoc) + return newdoc + + +def _is_deprecated(func): + """Helper to check if func is wrapped by our deprecated decorator""" + if sys.version_info < (3, 5): + raise NotImplementedError("This is only available for python3.5 " + "or above") + closures = getattr(func, '__closure__', []) + if closures is None: + closures = [] + is_deprecated = ('deprecated' in ''.join([c.cell_contents + for c in closures + if isinstance(c.cell_contents, str)])) + return is_deprecated + + +class BaseEstimator(object): + """Base class for most objects in POT + adapted from sklearn BaseEstimator class + + Notes + ----- + All estimators should specify all the parameters that can be set + at the class level in their ``__init__`` as explicit keyword + arguments (no ``*args`` or ``**kwargs``).
+ """ + + @classmethod + def _get_param_names(cls): + """Get parameter names for the estimator""" + try: + from inspect import signature + except ImportError: + from .externals.funcsigs import signature + # fetch the constructor or the original constructor before + # deprecation wrapping if any + init = getattr(cls.__init__, 'deprecated_original', cls.__init__) + if init is object.__init__: + # No explicit constructor to introspect + return [] + + # introspect the constructor arguments to find the model parameters + # to represent + init_signature = signature(init) + # Consider the constructor parameters excluding 'self' + parameters = [p for p in init_signature.parameters.values() + if p.name != 'self' and p.kind != p.VAR_KEYWORD] + for p in parameters: + if p.kind == p.VAR_POSITIONAL: + raise RuntimeError("POT estimators should always " + "specify their parameters in the signature" + " of their __init__ (no varargs)." + " %s with constructor %s doesn't " + " follow this convention." + % (cls, init_signature)) + # Extract and sort argument names excluding 'self' + return sorted([p.name for p in parameters]) + + def get_params(self, deep=True): + """Get parameters for this estimator. + + Parameters + ---------- + deep : boolean, optional + If True, will return the parameters for this estimator and + contained subobjects that are estimators. + + Returns + ------- + params : mapping of string to any + Parameter names mapped to their values. + """ + out = dict() + for key in self._get_param_names(): + # We need deprecation warnings to always be on in order to + # catch deprecated param values. + # This is set in utils/__init__.py but it gets overwritten + # when running under python3 somehow. 
+ warnings.simplefilter("always", DeprecationWarning) + try: + with warnings.catch_warnings(record=True) as w: + value = getattr(self, key, None) + if len(w) and w[0].category == DeprecationWarning: + # if the parameter is deprecated, don't show it + continue + finally: + warnings.filters.pop(0) + + # XXX: should we rather test if instance of estimator? + if deep and hasattr(value, 'get_params'): + deep_items = value.get_params().items() + out.update((key + '__' + k, val) for k, val in deep_items) + out[key] = value + return out + + def set_params(self, **params): + """Set the parameters of this estimator. + + The method works on simple estimators as well as on nested objects + (such as pipelines). The latter have parameters of the form + ``<component>__<parameter>`` so that it's possible to update each + component of a nested object. + + Returns + ------- + self + """ + if not params: + # Simple optimisation to gain speed (inspect is slow) + return self + valid_params = self.get_params(deep=True) + # for key, value in iteritems(params): + for key, value in params.items(): + split = key.split('__', 1) + if len(split) > 1: + # nested objects case + name, sub_name = split + if name not in valid_params: + raise ValueError('Invalid parameter %s for estimator %s. ' + 'Check the list of available parameters ' + 'with `estimator.get_params().keys()`.' % + (name, self)) + sub_object = valid_params[name] + sub_object.set_params(**{sub_name: value}) + else: + # simple objects case + if key not in valid_params: + raise ValueError('Invalid parameter %s for estimator %s. ' + 'Check the list of available parameters ' + 'with `estimator.get_params().keys()`.' % + (key, self.__class__.__name__)) + setattr(self, key, value) + return self -- cgit v1.2.3