summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Slasnista <stan.chambon@gmail.com> 2017-08-28 10:43:31 +0200
committer Slasnista <stan.chambon@gmail.com> 2017-08-28 10:43:31 +0200
commit 55840f6bccadd79caf722d86f06da857e3045453 (patch)
tree 833e706b9109579b68a94956a05d20e6fa773e49
parent 892d7ce10912de3d07692b5083578932a32ab61f (diff)
move no da objects into utils.py
-rw-r--r-- README.md 2
-rw-r--r-- ot/da.py 136
-rw-r--r-- ot/deprecation.py 103
-rw-r--r-- ot/utils.py 219
4 files changed, 222 insertions, 238 deletions
diff --git a/README.md b/README.md
index 7a65106..27b4643 100644
--- a/README.md
+++ b/README.md
@@ -136,6 +136,8 @@ The contributors to this library are:
* [Laetitia Chapel](http://people.irisa.fr/Laetitia.Chapel/)
* [Michael Perrot](http://perso.univ-st-etienne.fr/pem82055/) (Mapping estimation)
* [Léo Gautheron](https://github.com/aje) (GPU implementation)
+* Nathalie Gayraud
+* [Stanislas Chambon](https://slasnista.github.io/)
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
diff --git a/ot/da.py b/ot/da.py
index 5a34979..8c62669 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -10,14 +10,13 @@ Domain adaptation with optimal transport
# License: MIT License
import numpy as np
-import warnings
from .bregman import sinkhorn
from .lp import emd
from .utils import unif, dist, kernel
+from .utils import deprecated, BaseEstimator
from .optim import cg
from .optim import gcg
-from .deprecation import deprecated
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
@@ -936,139 +935,6 @@ class OTDA_mapping_kernel(OTDA_mapping_linear):
print("Warning, model not fitted yet, returning None")
return None
-##############################################################################
-# proposal
-##############################################################################
-
-
-# adapted from sklearn
-
-class BaseEstimator(object):
- """Base class for all estimators in scikit-learn
- Notes
- -----
- All estimators should specify all the parameters that can be set
- at the class level in their ``__init__`` as explicit keyword
- arguments (no ``*args`` or ``**kwargs``).
- """
-
- @classmethod
- def _get_param_names(cls):
- """Get parameter names for the estimator"""
- try:
- from inspect import signature
- except ImportError:
- from .externals.funcsigs import signature
- # fetch the constructor or the original constructor before
- # deprecation wrapping if any
- init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
- if init is object.__init__:
- # No explicit constructor to introspect
- return []
-
- # introspect the constructor arguments to find the model parameters
- # to represent
- init_signature = signature(init)
- # Consider the constructor parameters excluding 'self'
- parameters = [p for p in init_signature.parameters.values()
- if p.name != 'self' and p.kind != p.VAR_KEYWORD]
- for p in parameters:
- if p.kind == p.VAR_POSITIONAL:
- raise RuntimeError("scikit-learn estimators should always "
- "specify their parameters in the signature"
- " of their __init__ (no varargs)."
- " %s with constructor %s doesn't "
- " follow this convention."
- % (cls, init_signature))
- # Extract and sort argument names excluding 'self'
- return sorted([p.name for p in parameters])
-
- def get_params(self, deep=True):
- """Get parameters for this estimator.
-
- Parameters
- ----------
- deep : boolean, optional
- If True, will return the parameters for this estimator and
- contained subobjects that are estimators.
-
- Returns
- -------
- params : mapping of string to any
- Parameter names mapped to their values.
- """
- out = dict()
- for key in self._get_param_names():
- # We need deprecation warnings to always be on in order to
- # catch deprecated param values.
- # This is set in utils/__init__.py but it gets overwritten
- # when running under python3 somehow.
- warnings.simplefilter("always", DeprecationWarning)
- try:
- with warnings.catch_warnings(record=True) as w:
- value = getattr(self, key, None)
- if len(w) and w[0].category == DeprecationWarning:
- # if the parameter is deprecated, don't show it
- continue
- finally:
- warnings.filters.pop(0)
-
- # XXX: should we rather test if instance of estimator?
- if deep and hasattr(value, 'get_params'):
- deep_items = value.get_params().items()
- out.update((key + '__' + k, val) for k, val in deep_items)
- out[key] = value
- return out
-
- def set_params(self, **params):
- """Set the parameters of this estimator.
-
- The method works on simple estimators as well as on nested objects
- (such as pipelines). The latter have parameters of the form
- ``<component>__<parameter>`` so that it's possible to update each
- component of a nested object.
-
- Returns
- -------
- self
- """
- if not params:
- # Simple optimisation to gain speed (inspect is slow)
- return self
- valid_params = self.get_params(deep=True)
- # for key, value in iteritems(params):
- for key, value in params.items():
- split = key.split('__', 1)
- if len(split) > 1:
- # nested objects case
- name, sub_name = split
- if name not in valid_params:
- raise ValueError('Invalid parameter %s for estimator %s. '
- 'Check the list of available parameters '
- 'with `estimator.get_params().keys()`.' %
- (name, self))
- sub_object = valid_params[name]
- sub_object.set_params(**{sub_name: value})
- else:
- # simple objects case
- if key not in valid_params:
- raise ValueError('Invalid parameter %s for estimator %s. '
- 'Check the list of available parameters '
- 'with `estimator.get_params().keys()`.' %
- (key, self.__class__.__name__))
- setattr(self, key, value)
- return self
-
- def __repr__(self):
- from sklearn.base import _pprint
- class_name = self.__class__.__name__
- return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False),
- offset=len(class_name),),)
-
- # __getstate__ and __setstate__ are omitted because they only contain
- # conditionals that are not satisfied by our objects (e.g.,
- # ``if type(self).__module__.startswith('sklearn.')``.
-
def distribution_estimation_uniform(X):
"""estimates a uniform distribution from an array of samples X
diff --git a/ot/deprecation.py b/ot/deprecation.py
deleted file mode 100644
index 2b16427..0000000
--- a/ot/deprecation.py
+++ /dev/null
@@ -1,103 +0,0 @@
-"""
- deprecated class from scikit-learn package
- https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py
-"""
-
-import sys
-import warnings
-
-__all__ = ["deprecated", ]
-
-
-class deprecated(object):
- """Decorator to mark a function or class as deprecated.
- Issue a warning when the function is called/the class is instantiated and
- adds a warning to the docstring.
- The optional extra argument will be appended to the deprecation message
- and the docstring. Note: to use this with the default value for extra, put
- in an empty of parentheses:
- >>> from ot.deprecation import deprecated
- >>> @deprecated()
- ... def some_function(): pass
-
- Parameters
- ----------
- extra : string
- to be added to the deprecation messages
- """
-
- # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary,
- # but with many changes.
-
- def __init__(self, extra=''):
- self.extra = extra
-
- def __call__(self, obj):
- """Call method
- Parameters
- ----------
- obj : object
- """
- if isinstance(obj, type):
- return self._decorate_class(obj)
- else:
- return self._decorate_fun(obj)
-
- def _decorate_class(self, cls):
- msg = "Class %s is deprecated" % cls.__name__
- if self.extra:
- msg += "; %s" % self.extra
-
- # FIXME: we should probably reset __new__ for full generality
- init = cls.__init__
-
- def wrapped(*args, **kwargs):
- warnings.warn(msg, category=DeprecationWarning)
- return init(*args, **kwargs)
-
- cls.__init__ = wrapped
-
- wrapped.__name__ = '__init__'
- wrapped.__doc__ = self._update_doc(init.__doc__)
- wrapped.deprecated_original = init
-
- return cls
-
- def _decorate_fun(self, fun):
- """Decorate function fun"""
-
- msg = "Function %s is deprecated" % fun.__name__
- if self.extra:
- msg += "; %s" % self.extra
-
- def wrapped(*args, **kwargs):
- warnings.warn(msg, category=DeprecationWarning)
- return fun(*args, **kwargs)
-
- wrapped.__name__ = fun.__name__
- wrapped.__dict__ = fun.__dict__
- wrapped.__doc__ = self._update_doc(fun.__doc__)
-
- return wrapped
-
- def _update_doc(self, olddoc):
- newdoc = "DEPRECATED"
- if self.extra:
- newdoc = "%s: %s" % (newdoc, self.extra)
- if olddoc:
- newdoc = "%s\n\n%s" % (newdoc, olddoc)
- return newdoc
-
-
-def _is_deprecated(func):
- """Helper to check if func is wraped by our deprecated decorator"""
- if sys.version_info < (3, 5):
- raise NotImplementedError("This is only available for python3.5 "
- "or above")
- closures = getattr(func, '__closure__', [])
- if closures is None:
- closures = []
- is_deprecated = ('deprecated' in ''.join([c.cell_contents
- for c in closures
- if isinstance(c.cell_contents, str)]))
- return is_deprecated
diff --git a/ot/utils.py b/ot/utils.py
index 2b2f8b3..29ad536 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -13,6 +13,9 @@ import time
import numpy as np
from scipy.spatial.distance import cdist
+import sys
+import warnings
+
__time_tic_toc = time.time()
@@ -163,3 +166,219 @@ def parmap(f, X, nprocs=multiprocessing.cpu_count()):
[p.join() for p in proc]
return [x for i, x in sorted(res)]
+
+
+class deprecated(object):
+ """Decorator to mark a function or class as deprecated.
+
+ deprecated class from scikit-learn package
+ https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py
+ Issue a warning when the function is called/the class is instantiated and
+ adds a warning to the docstring.
+ The optional extra argument will be appended to the deprecation message
+ and the docstring. Note: to use this with the default value for extra, put
+ in an empty pair of parentheses:
+ >>> from ot.utils import deprecated
+ >>> @deprecated()
+ ... def some_function(): pass
+
+ Parameters
+ ----------
+ extra : string
+ to be added to the deprecation messages
+ """
+
+ # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary,
+ # but with many changes.
+
+ def __init__(self, extra=''):
+ self.extra = extra
+
+ def __call__(self, obj):
+ """Call method
+ Parameters
+ ----------
+ obj : object
+ """
+ if isinstance(obj, type):
+ return self._decorate_class(obj)
+ else:
+ return self._decorate_fun(obj)
+
+ def _decorate_class(self, cls):
+ msg = "Class %s is deprecated" % cls.__name__
+ if self.extra:
+ msg += "; %s" % self.extra
+
+ # FIXME: we should probably reset __new__ for full generality
+ init = cls.__init__
+
+ def wrapped(*args, **kwargs):
+ warnings.warn(msg, category=DeprecationWarning)
+ return init(*args, **kwargs)
+
+ cls.__init__ = wrapped
+
+ wrapped.__name__ = '__init__'
+ wrapped.__doc__ = self._update_doc(init.__doc__)
+ wrapped.deprecated_original = init
+
+ return cls
+
+ def _decorate_fun(self, fun):
+ """Decorate function fun"""
+
+ msg = "Function %s is deprecated" % fun.__name__
+ if self.extra:
+ msg += "; %s" % self.extra
+
+ def wrapped(*args, **kwargs):
+ warnings.warn(msg, category=DeprecationWarning)
+ return fun(*args, **kwargs)
+
+ wrapped.__name__ = fun.__name__
+ wrapped.__dict__ = fun.__dict__
+ wrapped.__doc__ = self._update_doc(fun.__doc__)
+
+ return wrapped
+
+ def _update_doc(self, olddoc):
+ newdoc = "DEPRECATED"
+ if self.extra:
+ newdoc = "%s: %s" % (newdoc, self.extra)
+ if olddoc:
+ newdoc = "%s\n\n%s" % (newdoc, olddoc)
+ return newdoc
+
+
+def _is_deprecated(func):
+ """Helper to check if func is wrapped by our deprecated decorator"""
+ if sys.version_info < (3, 5):
+ raise NotImplementedError("This is only available for python3.5 "
+ "or above")
+ closures = getattr(func, '__closure__', [])
+ if closures is None:
+ closures = []
+ is_deprecated = ('deprecated' in ''.join([c.cell_contents
+ for c in closures
+ if isinstance(c.cell_contents, str)]))
+ return is_deprecated
+
+
+class BaseEstimator(object):
+ """Base class for most objects in POT
+ adapted from sklearn BaseEstimator class
+
+ Notes
+ -----
+ All estimators should specify all the parameters that can be set
+ at the class level in their ``__init__`` as explicit keyword
+ arguments (no ``*args`` or ``**kwargs``).
+ """
+
+ @classmethod
+ def _get_param_names(cls):
+ """Get parameter names for the estimator"""
+ try:
+ from inspect import signature
+ except ImportError:
+ from .externals.funcsigs import signature
+ # fetch the constructor or the original constructor before
+ # deprecation wrapping if any
+ init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
+ if init is object.__init__:
+ # No explicit constructor to introspect
+ return []
+
+ # introspect the constructor arguments to find the model parameters
+ # to represent
+ init_signature = signature(init)
+ # Consider the constructor parameters excluding 'self'
+ parameters = [p for p in init_signature.parameters.values()
+ if p.name != 'self' and p.kind != p.VAR_KEYWORD]
+ for p in parameters:
+ if p.kind == p.VAR_POSITIONAL:
+ raise RuntimeError("POT estimators should always "
+ "specify their parameters in the signature"
+ " of their __init__ (no varargs)."
+ " %s with constructor %s doesn't "
+ " follow this convention."
+ % (cls, init_signature))
+ # Extract and sort argument names excluding 'self'
+ return sorted([p.name for p in parameters])
+
+ def get_params(self, deep=True):
+ """Get parameters for this estimator.
+
+ Parameters
+ ----------
+ deep : boolean, optional
+ If True, will return the parameters for this estimator and
+ contained subobjects that are estimators.
+
+ Returns
+ -------
+ params : mapping of string to any
+ Parameter names mapped to their values.
+ """
+ out = dict()
+ for key in self._get_param_names():
+ # We need deprecation warnings to always be on in order to
+ # catch deprecated param values.
+ # scikit-learn sets this in its utils/__init__.py but it gets overwritten
+ # when running under python3 somehow.
+ warnings.simplefilter("always", DeprecationWarning)
+ try:
+ with warnings.catch_warnings(record=True) as w:
+ value = getattr(self, key, None)
+ if len(w) and w[0].category == DeprecationWarning:
+ # if the parameter is deprecated, don't show it
+ continue
+ finally:
+ warnings.filters.pop(0)
+
+ # XXX: should we rather test if instance of estimator?
+ if deep and hasattr(value, 'get_params'):
+ deep_items = value.get_params().items()
+ out.update((key + '__' + k, val) for k, val in deep_items)
+ out[key] = value
+ return out
+
+ def set_params(self, **params):
+ """Set the parameters of this estimator.
+
+ The method works on simple estimators as well as on nested objects
+ (such as pipelines). The latter have parameters of the form
+ ``<component>__<parameter>`` so that it's possible to update each
+ component of a nested object.
+
+ Returns
+ -------
+ self
+ """
+ if not params:
+ # Simple optimisation to gain speed (inspect is slow)
+ return self
+ valid_params = self.get_params(deep=True)
+ # for key, value in iteritems(params):
+ for key, value in params.items():
+ split = key.split('__', 1)
+ if len(split) > 1:
+ # nested objects case
+ name, sub_name = split
+ if name not in valid_params:
+ raise ValueError('Invalid parameter %s for estimator %s. '
+ 'Check the list of available parameters '
+ 'with `estimator.get_params().keys()`.' %
+ (name, self))
+ sub_object = valid_params[name]
+ sub_object.set_params(**{sub_name: value})
+ else:
+ # simple objects case
+ if key not in valid_params:
+ raise ValueError('Invalid parameter %s for estimator %s. '
+ 'Check the list of available parameters '
+ 'with `estimator.get_params().keys()`.' %
+ (key, self.__class__.__name__))
+ setattr(self, key, value)
+ return self