diff options
author | Slasnista <stan.chambon@gmail.com> | 2017-08-28 10:43:31 +0200 |
---|---|---|
committer | Slasnista <stan.chambon@gmail.com> | 2017-08-28 10:43:31 +0200 |
commit | 55840f6bccadd79caf722d86f06da857e3045453 (patch) | |
tree | 833e706b9109579b68a94956a05d20e6fa773e49 /ot/da.py | |
parent | 892d7ce10912de3d07692b5083578932a32ab61f (diff) |
move no da objects into utils.py
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 136 |
1 files changed, 1 insertions, 135 deletions
@@ -10,14 +10,13 @@ Domain adaptation with optimal transport # License: MIT License import numpy as np -import warnings from .bregman import sinkhorn from .lp import emd from .utils import unif, dist, kernel +from .utils import deprecated, BaseEstimator from .optim import cg from .optim import gcg -from .deprecation import deprecated def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, @@ -936,139 +935,6 @@ class OTDA_mapping_kernel(OTDA_mapping_linear): print("Warning, model not fitted yet, returning None") return None -############################################################################## -# proposal -############################################################################## - - -# adapted from sklearn - -class BaseEstimator(object): - """Base class for all estimators in scikit-learn - Notes - ----- - All estimators should specify all the parameters that can be set - at the class level in their ``__init__`` as explicit keyword - arguments (no ``*args`` or ``**kwargs``). - """ - - @classmethod - def _get_param_names(cls): - """Get parameter names for the estimator""" - try: - from inspect import signature - except ImportError: - from .externals.funcsigs import signature - # fetch the constructor or the original constructor before - # deprecation wrapping if any - init = getattr(cls.__init__, 'deprecated_original', cls.__init__) - if init is object.__init__: - # No explicit constructor to introspect - return [] - - # introspect the constructor arguments to find the model parameters - # to represent - init_signature = signature(init) - # Consider the constructor parameters excluding 'self' - parameters = [p for p in init_signature.parameters.values() - if p.name != 'self' and p.kind != p.VAR_KEYWORD] - for p in parameters: - if p.kind == p.VAR_POSITIONAL: - raise RuntimeError("scikit-learn estimators should always " - "specify their parameters in the signature" - " of their __init__ (no varargs)." - " %s with constructor %s doesn't " - " follow this convention." - % (cls, init_signature)) - # Extract and sort argument names excluding 'self' - return sorted([p.name for p in parameters]) - - def get_params(self, deep=True): - """Get parameters for this estimator. - - Parameters - ---------- - deep : boolean, optional - If True, will return the parameters for this estimator and - contained subobjects that are estimators. - - Returns - ------- - params : mapping of string to any - Parameter names mapped to their values. - """ - out = dict() - for key in self._get_param_names(): - # We need deprecation warnings to always be on in order to - # catch deprecated param values. - # This is set in utils/__init__.py but it gets overwritten - # when running under python3 somehow. - warnings.simplefilter("always", DeprecationWarning) - try: - with warnings.catch_warnings(record=True) as w: - value = getattr(self, key, None) - if len(w) and w[0].category == DeprecationWarning: - # if the parameter is deprecated, don't show it - continue - finally: - warnings.filters.pop(0) - - # XXX: should we rather test if instance of estimator? - if deep and hasattr(value, 'get_params'): - deep_items = value.get_params().items() - out.update((key + '__' + k, val) for k, val in deep_items) - out[key] = value - return out - - def set_params(self, **params): - """Set the parameters of this estimator. - - The method works on simple estimators as well as on nested objects - (such as pipelines). The latter have parameters of the form - ``<component>__<parameter>`` so that it's possible to update each - component of a nested object. - - Returns - ------- - self - """ - if not params: - # Simple optimisation to gain speed (inspect is slow) - return self - valid_params = self.get_params(deep=True) - # for key, value in iteritems(params): - for key, value in params.items(): - split = key.split('__', 1) - if len(split) > 1: - # nested objects case - name, sub_name = split - if name not in valid_params: - raise ValueError('Invalid parameter %s for estimator %s. ' - 'Check the list of available parameters ' - 'with `estimator.get_params().keys()`.' % - (name, self)) - sub_object = valid_params[name] - sub_object.set_params(**{sub_name: value}) - else: - # simple objects case - if key not in valid_params: - raise ValueError('Invalid parameter %s for estimator %s. ' - 'Check the list of available parameters ' - 'with `estimator.get_params().keys()`.' % - (key, self.__class__.__name__)) - setattr(self, key, value) - return self - - def __repr__(self): - from sklearn.base import _pprint - class_name = self.__class__.__name__ - return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False), - offset=len(class_name),),) - - # __getstate__ and __setstate__ are omitted because they only contain - # conditionals that are not satisfied by our objects (e.g., - # ``if type(self).__module__.startswith('sklearn.')``. - def distribution_estimation_uniform(X): """estimates a uniform distribution from an array of samples X |