summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-08-28 10:43:31 +0200
committerSlasnista <stan.chambon@gmail.com>2017-08-28 10:43:31 +0200
commit55840f6bccadd79caf722d86f06da857e3045453 (patch)
tree833e706b9109579b68a94956a05d20e6fa773e49 /ot/da.py
parent892d7ce10912de3d07692b5083578932a32ab61f (diff)
move no da objects into utils.py
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py136
1 files changed, 1 insertions, 135 deletions
diff --git a/ot/da.py b/ot/da.py
index 5a34979..8c62669 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -10,14 +10,13 @@ Domain adaptation with optimal transport
# License: MIT License
import numpy as np
-import warnings
from .bregman import sinkhorn
from .lp import emd
from .utils import unif, dist, kernel
+from .utils import deprecated, BaseEstimator
from .optim import cg
from .optim import gcg
-from .deprecation import deprecated
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
@@ -936,139 +935,6 @@ class OTDA_mapping_kernel(OTDA_mapping_linear):
print("Warning, model not fitted yet, returning None")
return None
-##############################################################################
-# proposal
-##############################################################################
-
-
-# adapted from sklearn
-
-class BaseEstimator(object):
- """Base class for all estimators in scikit-learn
- Notes
- -----
- All estimators should specify all the parameters that can be set
- at the class level in their ``__init__`` as explicit keyword
- arguments (no ``*args`` or ``**kwargs``).
- """
-
- @classmethod
- def _get_param_names(cls):
- """Get parameter names for the estimator"""
- try:
- from inspect import signature
- except ImportError:
- from .externals.funcsigs import signature
- # fetch the constructor or the original constructor before
- # deprecation wrapping if any
- init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
- if init is object.__init__:
- # No explicit constructor to introspect
- return []
-
- # introspect the constructor arguments to find the model parameters
- # to represent
- init_signature = signature(init)
- # Consider the constructor parameters excluding 'self'
- parameters = [p for p in init_signature.parameters.values()
- if p.name != 'self' and p.kind != p.VAR_KEYWORD]
- for p in parameters:
- if p.kind == p.VAR_POSITIONAL:
- raise RuntimeError("scikit-learn estimators should always "
- "specify their parameters in the signature"
- " of their __init__ (no varargs)."
- " %s with constructor %s doesn't "
- " follow this convention."
- % (cls, init_signature))
- # Extract and sort argument names excluding 'self'
- return sorted([p.name for p in parameters])
-
- def get_params(self, deep=True):
- """Get parameters for this estimator.
-
- Parameters
- ----------
- deep : boolean, optional
- If True, will return the parameters for this estimator and
- contained subobjects that are estimators.
-
- Returns
- -------
- params : mapping of string to any
- Parameter names mapped to their values.
- """
- out = dict()
- for key in self._get_param_names():
- # We need deprecation warnings to always be on in order to
- # catch deprecated param values.
- # This is set in utils/__init__.py but it gets overwritten
- # when running under python3 somehow.
- warnings.simplefilter("always", DeprecationWarning)
- try:
- with warnings.catch_warnings(record=True) as w:
- value = getattr(self, key, None)
- if len(w) and w[0].category == DeprecationWarning:
- # if the parameter is deprecated, don't show it
- continue
- finally:
- warnings.filters.pop(0)
-
- # XXX: should we rather test if instance of estimator?
- if deep and hasattr(value, 'get_params'):
- deep_items = value.get_params().items()
- out.update((key + '__' + k, val) for k, val in deep_items)
- out[key] = value
- return out
-
- def set_params(self, **params):
- """Set the parameters of this estimator.
-
- The method works on simple estimators as well as on nested objects
- (such as pipelines). The latter have parameters of the form
- ``<component>__<parameter>`` so that it's possible to update each
- component of a nested object.
-
- Returns
- -------
- self
- """
- if not params:
- # Simple optimisation to gain speed (inspect is slow)
- return self
- valid_params = self.get_params(deep=True)
- # for key, value in iteritems(params):
- for key, value in params.items():
- split = key.split('__', 1)
- if len(split) > 1:
- # nested objects case
- name, sub_name = split
- if name not in valid_params:
- raise ValueError('Invalid parameter %s for estimator %s. '
- 'Check the list of available parameters '
- 'with `estimator.get_params().keys()`.' %
- (name, self))
- sub_object = valid_params[name]
- sub_object.set_params(**{sub_name: value})
- else:
- # simple objects case
- if key not in valid_params:
- raise ValueError('Invalid parameter %s for estimator %s. '
- 'Check the list of available parameters '
- 'with `estimator.get_params().keys()`.' %
- (key, self.__class__.__name__))
- setattr(self, key, value)
- return self
-
- def __repr__(self):
- from sklearn.base import _pprint
- class_name = self.__class__.__name__
- return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False),
- offset=len(class_name),),)
-
- # __getstate__ and __setstate__ are omitted because they only contain
- # conditionals that are not satisfied by our objects (e.g.,
- # ``if type(self).__module__.startswith('sklearn.')``.
-
def distribution_estimation_uniform(X):
"""estimates a uniform distribution from an array of samples X