summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-07-31 10:54:28 +0200
committerSlasnista <stan.chambon@gmail.com>2017-07-31 10:54:28 +0200
commitbd7c7d2534980d3105d060dd24a444433422134d (patch)
tree5297ae89c602fcf5b28b84c6a3e3619c2a40a3cd /ot/da.py
parentcd3397f852d8bf99e5de59069529b95d9ba00a05 (diff)
own BaseEstimator class written + rflamary comments addressed
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py199
1 files changed, 172 insertions, 27 deletions
diff --git a/ot/da.py b/ot/da.py
index 828efc2..d30c821 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -921,21 +921,153 @@ class OTDA_mapping_kernel(OTDA_mapping_linear):
# proposal
##############################################################################
-from sklearn.base import BaseEstimator
-from sklearn.metrics import pairwise_distances
+# from sklearn.base import BaseEstimator
+# from sklearn.metrics import pairwise_distances
+
+##############################################################################
+# adapted from scikit-learn
+
+import warnings
+# from .externals.six import string_types, iteritems
-"""
-- all methods have the same input parameters: Xs, Xt, ys, yt (what order ?)
-- reg_e: is the entropic reg parameter
-- reg_cl: is the second reg parameter
-- gamma_: is the optimal coupling
-- mapping barycentric for the moment
-
-Questions:
-- Cost matrix estimation: from sklearn or from internal function ?
-- distribution estimation ? Look at Nathalie's approach
-- should everything been done into the fit from BaseTransport ?
-"""
+
+class BaseEstimator(object):
+ """Base class for all estimators in scikit-learn
+ Notes
+ -----
+ All estimators should specify all the parameters that can be set
+ at the class level in their ``__init__`` as explicit keyword
+ arguments (no ``*args`` or ``**kwargs``).
+ """
+
+ @classmethod
+ def _get_param_names(cls):
+ """Get parameter names for the estimator"""
+ try:
+ from inspect import signature
+ except ImportError:
+ from .externals.funcsigs import signature
+ # fetch the constructor or the original constructor before
+ # deprecation wrapping if any
+ init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
+ if init is object.__init__:
+ # No explicit constructor to introspect
+ return []
+
+ # introspect the constructor arguments to find the model parameters
+ # to represent
+ init_signature = signature(init)
+ # Consider the constructor parameters excluding 'self'
+ parameters = [p for p in init_signature.parameters.values()
+ if p.name != 'self' and p.kind != p.VAR_KEYWORD]
+ for p in parameters:
+ if p.kind == p.VAR_POSITIONAL:
+ raise RuntimeError("scikit-learn estimators should always "
+ "specify their parameters in the signature"
+ " of their __init__ (no varargs)."
+ " %s with constructor %s doesn't "
+ " follow this convention."
+ % (cls, init_signature))
+ # Extract and sort argument names excluding 'self'
+ return sorted([p.name for p in parameters])
+
+ def get_params(self, deep=True):
+ """Get parameters for this estimator.
+ Parameters
+ ----------
+ deep : boolean, optional
+ If True, will return the parameters for this estimator and
+ contained subobjects that are estimators.
+ Returns
+ -------
+ params : mapping of string to any
+ Parameter names mapped to their values.
+ """
+ out = dict()
+ for key in self._get_param_names():
+ # We need deprecation warnings to always be on in order to
+ # catch deprecated param values.
+ # This is set in utils/__init__.py but it gets overwritten
+ # when running under python3 somehow.
+ warnings.simplefilter("always", DeprecationWarning)
+ try:
+ with warnings.catch_warnings(record=True) as w:
+ value = getattr(self, key, None)
+ if len(w) and w[0].category == DeprecationWarning:
+ # if the parameter is deprecated, don't show it
+ continue
+ finally:
+ warnings.filters.pop(0)
+
+ # XXX: should we rather test if instance of estimator?
+ if deep and hasattr(value, 'get_params'):
+ deep_items = value.get_params().items()
+ out.update((key + '__' + k, val) for k, val in deep_items)
+ out[key] = value
+ return out
+
+ def set_params(self, **params):
+ """Set the parameters of this estimator.
+ The method works on simple estimators as well as on nested objects
+ (such as pipelines). The latter have parameters of the form
+ ``<component>__<parameter>`` so that it's possible to update each
+ component of a nested object.
+ Returns
+ -------
+ self
+ """
+ if not params:
+ # Simple optimisation to gain speed (inspect is slow)
+ return self
+ valid_params = self.get_params(deep=True)
+ # for key, value in iteritems(params):
+ for key, value in params.items():
+ split = key.split('__', 1)
+ if len(split) > 1:
+ # nested objects case
+ name, sub_name = split
+ if name not in valid_params:
+ raise ValueError('Invalid parameter %s for estimator %s. '
+ 'Check the list of available parameters '
+ 'with `estimator.get_params().keys()`.' %
+ (name, self))
+ sub_object = valid_params[name]
+ sub_object.set_params(**{sub_name: value})
+ else:
+ # simple objects case
+ if key not in valid_params:
+ raise ValueError('Invalid parameter %s for estimator %s. '
+ 'Check the list of available parameters '
+ 'with `estimator.get_params().keys()`.' %
+ (key, self.__class__.__name__))
+ setattr(self, key, value)
+ return self
+
+ def __repr__(self):
+ from sklearn.base import _pprint
+ class_name = self.__class__.__name__
+ return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False),
+ offset=len(class_name),),)
+
+ # __getstate__ and __setstate__ are omitted because they only contain
+ # conditionals that are not satisfied by our objects (e.g.,
+ # ``if type(self).__module__.startswith('sklearn.')``.
+
+
+def distribution_estimation_uniform(X):
+ """estimates a uniform distribution from an array of samples X
+
+ Parameters
+ ----------
+ X : array-like of shape = [n_samples, n_features]
+ The array of samples
+ Returns
+ -------
+ mu : array-like, shape = [n_samples,]
+ The uniform distribution estimated from X
+ """
+
+ return np.ones(X.shape[0]) / float(X.shape[0])
class BaseTransport(BaseEstimator):
@@ -960,18 +1092,19 @@ class BaseTransport(BaseEstimator):
"""
# pairwise distance
- Cost = pairwise_distances(Xs, Xt, metric=self.metric)
+ Cost = dist(Xs, Xt, metric=self.metric)
if self.mode == "semisupervised":
print("TODO: modify cost matrix accordingly")
pass
# distribution estimation
- if self.distribution == "uniform":
- mu_s = np.ones(Xs.shape[0]) / float(Xs.shape[0])
- mu_t = np.ones(Xt.shape[0]) / float(Xt.shape[0])
- else:
- print("TODO: implement kernelized approach")
+ mu_s = self.distribution_estimation(Xs)
+ mu_t = self.distribution_estimation(Xt)
+
+ # store arrays of samples
+ self.Xs = Xs
+ self.Xt = Xt
# coupling estimation
if self.method == "sinkhorn":
@@ -1024,14 +1157,19 @@ class BaseTransport(BaseEstimator):
The transport source samples.
"""
- if self.mapping == "barycentric":
+ # TODO: check whether Xs is new or not
+ if self.Xs == Xs:
+ # perform standard barycentric mapping
transp = self.gamma_ / np.sum(self.gamma_, 1)[:, None]
# set nans to 0
transp[~ np.isfinite(transp)] = 0
# compute transported samples
- transp_Xs = np.dot(transp, Xt)
+ transp_Xs = np.dot(transp, self.Xt)
+ else:
+ # perform out of sample mapping
+ print("out of sample mapping not yet implemented")
return transp_Xs
@@ -1053,16 +1191,19 @@ class BaseTransport(BaseEstimator):
The transported target samples.
"""
- if self.mapping == "barycentric":
+ # TODO: check whether Xt is new or not
+ if self.Xt == Xt:
+ # perform standard barycentric mapping
transp_ = self.gamma_.T / np.sum(self.gamma_, 0)[:, None]
# set nans to 0
transp_[~ np.isfinite(transp_)] = 0
# compute transported samples
- transp_Xt = np.dot(transp_, Xs)
+ transp_Xt = np.dot(transp_, self.Xs)
else:
- print("mapping not yet implemented")
+ # perform out of sample mapping
+ print("out of sample mapping not yet implemented")
return transp_Xt
@@ -1114,7 +1255,10 @@ class SinkhornTransport(BaseTransport):
def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
tol=10e-9, verbose=False, log=False, mapping="barycentric",
- metric="sqeuclidean", distribution="uniform"):
+ metric="sqeuclidean",
+ distribution_estimation=distribution_estimation_uniform,
+ out_of_sample_map='ferradans'):
+
self.reg_e = reg_e
self.mode = mode
self.max_iter = max_iter
@@ -1123,8 +1267,9 @@ class SinkhornTransport(BaseTransport):
self.log = log
self.mapping = mapping
self.metric = metric
- self.distribution = distribution
+ self.distribution_estimation = distribution_estimation
self.method = "sinkhorn"
+ self.out_of_sample_map = out_of_sample_map
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples