From 5ab50354e60ed94d9d799927fd4b680fb8447304 Mon Sep 17 00:00:00 2001 From: Slasnista Date: Mon, 31 Jul 2017 10:54:28 +0200 Subject: own BaseEstimator class written + rflamary comments addressed --- ot/da.py | 199 ++++++++++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 172 insertions(+), 27 deletions(-) diff --git a/ot/da.py b/ot/da.py index 828efc2..d30c821 100644 --- a/ot/da.py +++ b/ot/da.py @@ -921,21 +921,153 @@ class OTDA_mapping_kernel(OTDA_mapping_linear): # proposal ############################################################################## -from sklearn.base import BaseEstimator -from sklearn.metrics import pairwise_distances +# from sklearn.base import BaseEstimator +# from sklearn.metrics import pairwise_distances + +############################################################################## +# adapted from scikit-learn + +import warnings +# from .externals.six import string_types, iteritems -""" -- all methods have the same input parameters: Xs, Xt, ys, yt (what order ?) -- reg_e: is the entropic reg parameter -- reg_cl: is the second reg parameter -- gamma_: is the optimal coupling -- mapping barycentric for the moment - -Questions: -- Cost matrix estimation: from sklearn or from internal function ? -- distribution estimation ? Look at Nathalie's approach -- should everything been done into the fit from BaseTransport ? -""" + +class BaseEstimator(object): + """Base class for all estimators in scikit-learn + Notes + ----- + All estimators should specify all the parameters that can be set + at the class level in their ``__init__`` as explicit keyword + arguments (no ``*args`` or ``**kwargs``). + """ + + @classmethod + def _get_param_names(cls): + """Get parameter names for the estimator""" + try: + from inspect import signature + except ImportError: + from .externals.funcsigs import signature + # fetch the constructor or the original constructor before + # deprecation wrapping if any + init = getattr(cls.__init__, 'deprecated_original', cls.__init__) + if init is object.__init__: + # No explicit constructor to introspect + return [] + + # introspect the constructor arguments to find the model parameters + # to represent + init_signature = signature(init) + # Consider the constructor parameters excluding 'self' + parameters = [p for p in init_signature.parameters.values() + if p.name != 'self' and p.kind != p.VAR_KEYWORD] + for p in parameters: + if p.kind == p.VAR_POSITIONAL: + raise RuntimeError("scikit-learn estimators should always " + "specify their parameters in the signature" + " of their __init__ (no varargs)." + " %s with constructor %s doesn't " + " follow this convention." + % (cls, init_signature)) + # Extract and sort argument names excluding 'self' + return sorted([p.name for p in parameters]) + + def get_params(self, deep=True): + """Get parameters for this estimator. + Parameters + ---------- + deep : boolean, optional + If True, will return the parameters for this estimator and + contained subobjects that are estimators. + Returns + ------- + params : mapping of string to any + Parameter names mapped to their values. + """ + out = dict() + for key in self._get_param_names(): + # We need deprecation warnings to always be on in order to + # catch deprecated param values. + # This is set in utils/__init__.py but it gets overwritten + # when running under python3 somehow. + warnings.simplefilter("always", DeprecationWarning) + try: + with warnings.catch_warnings(record=True) as w: + value = getattr(self, key, None) + if len(w) and w[0].category == DeprecationWarning: + # if the parameter is deprecated, don't show it + continue + finally: + warnings.filters.pop(0) + + # XXX: should we rather test if instance of estimator? + if deep and hasattr(value, 'get_params'): + deep_items = value.get_params().items() + out.update((key + '__' + k, val) for k, val in deep_items) + out[key] = value + return out + + def set_params(self, **params): + """Set the parameters of this estimator. + The method works on simple estimators as well as on nested objects + (such as pipelines). The latter have parameters of the form + ``__`` so that it's possible to update each + component of a nested object. + Returns + ------- + self + """ + if not params: + # Simple optimisation to gain speed (inspect is slow) + return self + valid_params = self.get_params(deep=True) + # for key, value in iteritems(params): + for key, value in params.items(): + split = key.split('__', 1) + if len(split) > 1: + # nested objects case + name, sub_name = split + if name not in valid_params: + raise ValueError('Invalid parameter %s for estimator %s. ' + 'Check the list of available parameters ' + 'with `estimator.get_params().keys()`.' % + (name, self)) + sub_object = valid_params[name] + sub_object.set_params(**{sub_name: value}) + else: + # simple objects case + if key not in valid_params: + raise ValueError('Invalid parameter %s for estimator %s. ' + 'Check the list of available parameters ' + 'with `estimator.get_params().keys()`.' % + (key, self.__class__.__name__)) + setattr(self, key, value) + return self + + def __repr__(self): + from sklearn.base import _pprint + class_name = self.__class__.__name__ + return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False), + offset=len(class_name),),) + + # __getstate__ and __setstate__ are omitted because they only contain + # conditionals that are not satisfied by our objects (e.g., + # ``if type(self).__module__.startswith('sklearn.')``. + + +def distribution_estimation_uniform(X): + """estimates a uniform distribution from an array of samples X + + Parameters + ---------- + X : array-like of shape = [n_samples, n_features] + The array of samples + Returns + ------- + mu : array-like, shape = [n_samples,] + The uniform distribution estimated from X + """ + + return np.ones(X.shape[0]) / float(X.shape[0]) class BaseTransport(BaseEstimator): @@ -960,18 +1092,19 @@ class BaseTransport(BaseEstimator): """ # pairwise distance - Cost = pairwise_distances(Xs, Xt, metric=self.metric) + Cost = dist(Xs, Xt, metric=self.metric) if self.mode == "semisupervised": print("TODO: modify cost matrix accordingly") pass # distribution estimation - if self.distribution == "uniform": - mu_s = np.ones(Xs.shape[0]) / float(Xs.shape[0]) - mu_t = np.ones(Xt.shape[0]) / float(Xt.shape[0]) - else: - print("TODO: implement kernelized approach") + mu_s = self.distribution_estimation(Xs) + mu_t = self.distribution_estimation(Xt) + + # store arrays of samples + self.Xs = Xs + self.Xt = Xt # coupling estimation if self.method == "sinkhorn": @@ -1024,14 +1157,19 @@ class BaseTransport(BaseEstimator): The transport source samples. """ - if self.mapping == "barycentric": + # TODO: check whether Xs is new or not + if self.Xs == Xs: + # perform standard barycentric mapping transp = self.gamma_ / np.sum(self.gamma_, 1)[:, None] # set nans to 0 transp[~ np.isfinite(transp)] = 0 # compute transported samples - transp_Xs = np.dot(transp, Xt) + transp_Xs = np.dot(transp, self.Xt) + else: + # perform out of sample mapping + print("out of sample mapping not yet implemented") return transp_Xs @@ -1053,16 +1191,19 @@ class BaseTransport(BaseEstimator): The transported target samples. """ - if self.mapping == "barycentric": + # TODO: check whether Xt is new or not + if self.Xt == Xt: + # perform standard barycentric mapping transp_ = self.gamma_.T / np.sum(self.gamma_, 0)[:, None] # set nans to 0 transp_[~ np.isfinite(transp_)] = 0 # compute transported samples - transp_Xt = np.dot(transp_, Xs) + transp_Xt = np.dot(transp_, self.Xs) else: - print("mapping not yet implemented") + # perform out of sample mapping + print("out of sample mapping not yet implemented") return transp_Xt @@ -1114,7 +1255,10 @@ class SinkhornTransport(BaseTransport): def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000, tol=10e-9, verbose=False, log=False, mapping="barycentric", - metric="sqeuclidean", distribution="uniform"): + metric="sqeuclidean", + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map='ferradans'): + self.reg_e = reg_e self.mode = mode self.max_iter = max_iter @@ -1123,8 +1267,9 @@ class SinkhornTransport(BaseTransport): self.log = log self.mapping = mapping self.metric = metric - self.distribution = distribution + self.distribution_estimation = distribution_estimation self.method = "sinkhorn" + self.out_of_sample_map = out_of_sample_map def fit(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples -- cgit v1.2.3