summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-08-28 14:07:55 +0200
committerSlasnista <stan.chambon@gmail.com>2017-08-28 14:07:55 +0200
commitc5d7c40c1d850879abd5f2c513afa1a2c5d5987e (patch)
tree341b525dffe325be2709d72e0f0863ad9bfc66de
parenta8fa91bec26caa93329e61a104e0ad6afdf37363 (diff)
check input parameters with helper functions
-rw-r--r--ot/da.py174
-rw-r--r--ot/utils.py21
2 files changed, 120 insertions, 75 deletions
diff --git a/ot/da.py b/ot/da.py
index 369b6a2..78dc150 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -14,7 +14,7 @@ import numpy as np
from .bregman import sinkhorn
from .lp import emd
from .utils import unif, dist, kernel
-from .utils import deprecated, BaseEstimator
+from .utils import check_params, deprecated, BaseEstimator
from .optim import cg
from .optim import gcg
@@ -954,6 +954,26 @@ def distribution_estimation_uniform(X):
class BaseTransport(BaseEstimator):
+ """Base class for OTDA objects
+
+ Notes
+ -----
+ All estimators should specify all the parameters that can be set
+ at the class level in their ``__init__`` as explicit keyword
+ arguments (no ``*args`` or ``**kwargs``).
+
+ fit method should:
+ - estimate a cost matrix and store it in a `cost_` attribute
+ - estimate a coupling matrix and store it in a `coupling_`
+ attribute
+ - estimate distributions from source and target data and store them in
+ mu_s and mu_t attributes
+ - store Xs and Xt in attributes to be used later on in transform and
+ inverse_transform methods
+
+ transform method should always get as input a Xs parameter
+ inverse_transform method should always get as input a Xt parameter
+ """
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
@@ -976,7 +996,9 @@ class BaseTransport(BaseEstimator):
Returns self.
"""
- if Xs is not None and Xt is not None:
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt):
+
# pairwise distance
self.cost_ = dist(Xs, Xt, metric=self.metric)
@@ -1003,14 +1025,10 @@ class BaseTransport(BaseEstimator):
self.mu_t = self.distribution_estimation(Xt)
# store arrays of samples
- self.Xs = Xs
- self.Xt = Xt
+ self.xs_ = Xs
+ self.xt_ = Xt
- return self
- else:
- print("POT-Warning")
- print("Please provide both Xs and Xt arguments when calling")
- print("fit method")
+ return self
def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
@@ -1058,8 +1076,11 @@ class BaseTransport(BaseEstimator):
The transport source samples.
"""
- if Xs is not None:
- if np.array_equal(self.Xs, Xs):
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs):
+
+ if np.array_equal(self.xs_, Xs):
+
# perform standard barycentric mapping
transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
@@ -1067,7 +1088,7 @@ class BaseTransport(BaseEstimator):
transp[~ np.isfinite(transp)] = 0
# compute transported samples
- transp_Xs = np.dot(transp, self.Xt)
+ transp_Xs = np.dot(transp, self.xt_)
else:
# perform out of sample mapping
indices = np.arange(Xs.shape[0])
@@ -1079,26 +1100,23 @@ class BaseTransport(BaseEstimator):
for bi in batch_ind:
# get the nearest neighbor in the source domain
- D0 = dist(Xs[bi], self.Xs)
+ D0 = dist(Xs[bi], self.xs_)
idx = np.argmin(D0, axis=1)
# transport the source samples
transp = self.coupling_ / np.sum(
self.coupling_, 1)[:, None]
transp[~ np.isfinite(transp)] = 0
- transp_Xs_ = np.dot(transp, self.Xt)
+ transp_Xs_ = np.dot(transp, self.xt_)
# define the transported points
- transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.Xs[idx, :]
+ transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.xs_[idx, :]
transp_Xs.append(transp_Xs_)
transp_Xs = np.concatenate(transp_Xs, axis=0)
return transp_Xs
- else:
- print("POT-Warning")
- print("Please provide Xs argument when calling transform method")
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
@@ -1123,8 +1141,11 @@ class BaseTransport(BaseEstimator):
The transported target samples.
"""
- if Xt is not None:
- if np.array_equal(self.Xt, Xt):
+ # check the necessary inputs parameters are here
+ if check_params(Xt=Xt):
+
+ if np.array_equal(self.xt_, Xt):
+
# perform standard barycentric mapping
transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None]
@@ -1132,7 +1153,7 @@ class BaseTransport(BaseEstimator):
transp_[~ np.isfinite(transp_)] = 0
# compute transported samples
- transp_Xt = np.dot(transp_, self.Xs)
+ transp_Xt = np.dot(transp_, self.xs_)
else:
# perform out of sample mapping
indices = np.arange(Xt.shape[0])
@@ -1143,26 +1164,23 @@ class BaseTransport(BaseEstimator):
transp_Xt = []
for bi in batch_ind:
- D0 = dist(Xt[bi], self.Xt)
+ D0 = dist(Xt[bi], self.xt_)
idx = np.argmin(D0, axis=1)
# transport the target samples
transp_ = self.coupling_.T / np.sum(
self.coupling_, 0)[:, None]
transp_[~ np.isfinite(transp_)] = 0
- transp_Xt_ = np.dot(transp_, self.Xs)
+ transp_Xt_ = np.dot(transp_, self.xs_)
# define the transported points
- transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.Xt[idx, :]
+ transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.xt_[idx, :]
transp_Xt.append(transp_Xt_)
transp_Xt = np.concatenate(transp_Xt, axis=0)
return transp_Xt
- else:
- print("POT-Warning")
- print("Please provide Xt argument when calling inverse_transform")
class SinkhornTransport(BaseTransport):
@@ -1428,7 +1446,8 @@ class SinkhornLpl1Transport(BaseTransport):
Returns self.
"""
- if Xs is not None and Xt is not None and ys is not None:
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt, ys=ys):
super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt)
@@ -1438,10 +1457,7 @@ class SinkhornLpl1Transport(BaseTransport):
numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
verbose=self.verbose)
- return self
- else:
- print("POT-Warning")
- print("Please provide both Xs, Xt, ys arguments to fit method")
+ return self
class SinkhornL1l2Transport(BaseTransport):
@@ -1537,7 +1553,8 @@ class SinkhornL1l2Transport(BaseTransport):
Returns self.
"""
- if Xs is not None and Xt is not None and ys is not None:
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt, ys=ys):
super(SinkhornL1l2Transport, self).fit(Xs, ys, Xt, yt)
@@ -1554,10 +1571,7 @@ class SinkhornL1l2Transport(BaseTransport):
self.coupling_ = returned_
self.log_ = dict()
- return self
- else:
- print("POT-Warning")
- print("Please, provide both Xs, Xt and ys argument to fit method")
+ return self
class MappingTransport(BaseEstimator):
@@ -1652,29 +1666,35 @@ class MappingTransport(BaseEstimator):
Returns self
"""
- self.Xs = Xs
- self.Xt = Xt
-
- if self.kernel == "linear":
- returned_ = joint_OT_mapping_linear(
- Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
- verbose=self.verbose, verbose2=self.verbose2,
- numItermax=self.max_iter, numInnerItermax=self.max_inner_iter,
- stopThr=self.tol, stopInnerThr=self.inner_tol, log=self.log)
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt):
+
+ self.xs_ = Xs
+ self.xt_ = Xt
+
+ if self.kernel == "linear":
+ returned_ = joint_OT_mapping_linear(
+ Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
+ verbose=self.verbose, verbose2=self.verbose2,
+ numItermax=self.max_iter,
+ numInnerItermax=self.max_inner_iter, stopThr=self.tol,
+ stopInnerThr=self.inner_tol, log=self.log)
+
+ elif self.kernel == "gaussian":
+ returned_ = joint_OT_mapping_kernel(
+ Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
+ sigma=self.sigma, verbose=self.verbose,
+ verbose2=self.verbose, numItermax=self.max_iter,
+ numInnerItermax=self.max_inner_iter,
+ stopInnerThr=self.inner_tol, stopThr=self.tol,
+ log=self.log)
- elif self.kernel == "gaussian":
- returned_ = joint_OT_mapping_kernel(
- Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
- sigma=self.sigma, verbose=self.verbose, verbose2=self.verbose,
- numItermax=self.max_iter, numInnerItermax=self.max_inner_iter,
- stopInnerThr=self.inner_tol, stopThr=self.tol, log=self.log)
-
- # deal with the value of log
- if self.log:
- self.coupling_, self.mapping_, self.log_ = returned_
- else:
- self.coupling_, self.mapping_ = returned_
- self.log_ = dict()
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.mapping_, self.log_ = returned_
+ else:
+ self.coupling_, self.mapping_ = returned_
+ self.log_ = dict()
return self
@@ -1692,22 +1712,26 @@ class MappingTransport(BaseEstimator):
The transport source samples.
"""
- if np.array_equal(self.Xs, Xs):
- # perform standard barycentric mapping
- transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs):
- # set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ if np.array_equal(self.xs_, Xs):
+ # perform standard barycentric mapping
+ transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
- # compute transported samples
- transp_Xs = np.dot(transp, self.Xt)
- else:
- if self.kernel == "gaussian":
- K = kernel(Xs, self.Xs, method=self.kernel, sigma=self.sigma)
- elif self.kernel == "linear":
- K = Xs
- if self.bias:
- K = np.hstack((K, np.ones((Xs.shape[0], 1))))
- transp_Xs = K.dot(self.mapping_)
+ # set nans to 0
+ transp[~ np.isfinite(transp)] = 0
- return transp_Xs
+ # compute transported samples
+ transp_Xs = np.dot(transp, self.xt_)
+ else:
+ if self.kernel == "gaussian":
+ K = kernel(Xs, self.xs_, method=self.kernel,
+ sigma=self.sigma)
+ elif self.kernel == "linear":
+ K = Xs
+ if self.bias:
+ K = np.hstack((K, np.ones((Xs.shape[0], 1))))
+ transp_Xs = K.dot(self.mapping_)
+
+ return transp_Xs
diff --git a/ot/utils.py b/ot/utils.py
index 29ad536..01f2a67 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -168,6 +168,27 @@ def parmap(f, X, nprocs=multiprocessing.cpu_count()):
return [x for i, x in sorted(res)]
+def check_params(**kwargs):
+ """check_params: check whether some parameters are missing
+ """
+
+ missing_params = []
+ check = True
+
+ for param in kwargs:
+ if kwargs[param] is None:
+ missing_params.append(param)
+
+ if len(missing_params) > 0:
+ print("POT - Warning: following necessary parameters are missing")
+ for p in missing_params:
+ print("\n", p)
+
+ check = False
+
+ return check
+
+
class deprecated(object):
"""Decorator to mark a function or class as deprecated.