From 55840f6bccadd79caf722d86f06da857e3045453 Mon Sep 17 00:00:00 2001 From: Slasnista Date: Mon, 28 Aug 2017 10:43:31 +0200 Subject: move no da objects into utils.py --- ot/da.py | 136 +-------------------------------------------------------------- 1 file changed, 1 insertion(+), 135 deletions(-) (limited to 'ot/da.py') diff --git a/ot/da.py b/ot/da.py index 5a34979..8c62669 100644 --- a/ot/da.py +++ b/ot/da.py @@ -10,14 +10,13 @@ Domain adaptation with optimal transport # License: MIT License import numpy as np -import warnings from .bregman import sinkhorn from .lp import emd from .utils import unif, dist, kernel +from .utils import deprecated, BaseEstimator from .optim import cg from .optim import gcg -from .deprecation import deprecated def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, @@ -936,139 +935,6 @@ class OTDA_mapping_kernel(OTDA_mapping_linear): print("Warning, model not fitted yet, returning None") return None -############################################################################## -# proposal -############################################################################## - - -# adapted from sklearn - -class BaseEstimator(object): - """Base class for all estimators in scikit-learn - Notes - ----- - All estimators should specify all the parameters that can be set - at the class level in their ``__init__`` as explicit keyword - arguments (no ``*args`` or ``**kwargs``). - """ - - @classmethod - def _get_param_names(cls): - """Get parameter names for the estimator""" - try: - from inspect import signature - except ImportError: - from .externals.funcsigs import signature - # fetch the constructor or the original constructor before - # deprecation wrapping if any - init = getattr(cls.__init__, 'deprecated_original', cls.__init__) - if init is object.__init__: - # No explicit constructor to introspect - return [] - - # introspect the constructor arguments to find the model parameters - # to represent - init_signature = signature(init) - # Consider the constructor parameters excluding 'self' - parameters = [p for p in init_signature.parameters.values() - if p.name != 'self' and p.kind != p.VAR_KEYWORD] - for p in parameters: - if p.kind == p.VAR_POSITIONAL: - raise RuntimeError("scikit-learn estimators should always " - "specify their parameters in the signature" - " of their __init__ (no varargs)." - " %s with constructor %s doesn't " - " follow this convention." - % (cls, init_signature)) - # Extract and sort argument names excluding 'self' - return sorted([p.name for p in parameters]) - - def get_params(self, deep=True): - """Get parameters for this estimator. - - Parameters - ---------- - deep : boolean, optional - If True, will return the parameters for this estimator and - contained subobjects that are estimators. - - Returns - ------- - params : mapping of string to any - Parameter names mapped to their values. - """ - out = dict() - for key in self._get_param_names(): - # We need deprecation warnings to always be on in order to - # catch deprecated param values. - # This is set in utils/__init__.py but it gets overwritten - # when running under python3 somehow. - warnings.simplefilter("always", DeprecationWarning) - try: - with warnings.catch_warnings(record=True) as w: - value = getattr(self, key, None) - if len(w) and w[0].category == DeprecationWarning: - # if the parameter is deprecated, don't show it - continue - finally: - warnings.filters.pop(0) - - # XXX: should we rather test if instance of estimator? - if deep and hasattr(value, 'get_params'): - deep_items = value.get_params().items() - out.update((key + '__' + k, val) for k, val in deep_items) - out[key] = value - return out - - def set_params(self, **params): - """Set the parameters of this estimator. - - The method works on simple estimators as well as on nested objects - (such as pipelines). The latter have parameters of the form - ``__`` so that it's possible to update each - component of a nested object. - - Returns - ------- - self - """ - if not params: - # Simple optimisation to gain speed (inspect is slow) - return self - valid_params = self.get_params(deep=True) - # for key, value in iteritems(params): - for key, value in params.items(): - split = key.split('__', 1) - if len(split) > 1: - # nested objects case - name, sub_name = split - if name not in valid_params: - raise ValueError('Invalid parameter %s for estimator %s. ' - 'Check the list of available parameters ' - 'with `estimator.get_params().keys()`.' % - (name, self)) - sub_object = valid_params[name] - sub_object.set_params(**{sub_name: value}) - else: - # simple objects case - if key not in valid_params: - raise ValueError('Invalid parameter %s for estimator %s. ' - 'Check the list of available parameters ' - 'with `estimator.get_params().keys()`.' % - (key, self.__class__.__name__)) - setattr(self, key, value) - return self - - def __repr__(self): - from sklearn.base import _pprint - class_name = self.__class__.__name__ - return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False), - offset=len(class_name),),) - - # __getstate__ and __setstate__ are omitted because they only contain - # conditionals that are not satisfied by our objects (e.g., - # ``if type(self).__module__.startswith('sklearn.')``. - def distribution_estimation_uniform(X): """estimates a uniform distribution from an array of samples X -- cgit v1.2.3