diff options
Diffstat (limited to 'ot')
-rw-r--r-- | ot/da.py | 174 | ||||
-rw-r--r-- | ot/utils.py | 21 |
2 files changed, 120 insertions, 75 deletions
@@ -14,7 +14,7 @@ import numpy as np from .bregman import sinkhorn from .lp import emd from .utils import unif, dist, kernel -from .utils import deprecated, BaseEstimator +from .utils import check_params, deprecated, BaseEstimator from .optim import cg from .optim import gcg @@ -954,6 +954,26 @@ def distribution_estimation_uniform(X): class BaseTransport(BaseEstimator): + """Base class for OTDA objects + + Notes + ----- + All estimators should specify all the parameters that can be set + at the class level in their ``__init__`` as explicit keyword + arguments (no ``*args`` or ``**kwargs``). + + fit method should: + - estimate a cost matrix and store it in a `cost_` attribute + - estimate a coupling matrix and store it in a `coupling_` + attribute + - estimate distributions from source and target data and store them in + mu_s and mu_t attributes + - store Xs and Xt in attributes to be used later on in transform and + inverse_transform methods + + transform method should always get as input a Xs parameter + inverse_transform method should always get as input a Xt parameter + """ def fit(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples @@ -976,7 +996,9 @@ class BaseTransport(BaseEstimator): Returns self. """ - if Xs is not None and Xt is not None: + # check the necessary inputs parameters are here + if check_params(Xs=Xs, Xt=Xt): + # pairwise distance self.cost_ = dist(Xs, Xt, metric=self.metric) @@ -1003,14 +1025,10 @@ class BaseTransport(BaseEstimator): self.mu_t = self.distribution_estimation(Xt) # store arrays of samples - self.Xs = Xs - self.Xt = Xt + self.xs_ = Xs + self.xt_ = Xt - return self - else: - print("POT-Warning") - print("Please provide both Xs and Xt arguments when calling") - print("fit method") + return self def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples @@ -1058,8 +1076,11 @@ class BaseTransport(BaseEstimator): The transport source samples. """ - if Xs is not None: - if np.array_equal(self.Xs, Xs): + # check the necessary inputs parameters are here + if check_params(Xs=Xs): + + if np.array_equal(self.xs_, Xs): + # perform standard barycentric mapping transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] @@ -1067,7 +1088,7 @@ class BaseTransport(BaseEstimator): transp[~ np.isfinite(transp)] = 0 # compute transported samples - transp_Xs = np.dot(transp, self.Xt) + transp_Xs = np.dot(transp, self.xt_) else: # perform out of sample mapping indices = np.arange(Xs.shape[0]) @@ -1079,26 +1100,23 @@ class BaseTransport(BaseEstimator): for bi in batch_ind: # get the nearest neighbor in the source domain - D0 = dist(Xs[bi], self.Xs) + D0 = dist(Xs[bi], self.xs_) idx = np.argmin(D0, axis=1) # transport the source samples transp = self.coupling_ / np.sum( self.coupling_, 1)[:, None] transp[~ np.isfinite(transp)] = 0 - transp_Xs_ = np.dot(transp, self.Xt) + transp_Xs_ = np.dot(transp, self.xt_) # define the transported points - transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.Xs[idx, :] + transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.xs_[idx, :] transp_Xs.append(transp_Xs_) transp_Xs = np.concatenate(transp_Xs, axis=0) return transp_Xs - else: - print("POT-Warning") - print("Please provide Xs argument when calling transform method") def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): @@ -1123,8 +1141,11 @@ class BaseTransport(BaseEstimator): The transported target samples. """ - if Xt is not None: - if np.array_equal(self.Xt, Xt): + # check the necessary inputs parameters are here + if check_params(Xt=Xt): + + if np.array_equal(self.xt_, Xt): + # perform standard barycentric mapping transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None] @@ -1132,7 +1153,7 @@ class BaseTransport(BaseEstimator): transp_[~ np.isfinite(transp_)] = 0 # compute transported samples - transp_Xt = np.dot(transp_, self.Xs) + transp_Xt = np.dot(transp_, self.xs_) else: # perform out of sample mapping indices = np.arange(Xt.shape[0]) @@ -1143,26 +1164,23 @@ class BaseTransport(BaseEstimator): transp_Xt = [] for bi in batch_ind: - D0 = dist(Xt[bi], self.Xt) + D0 = dist(Xt[bi], self.xt_) idx = np.argmin(D0, axis=1) # transport the target samples transp_ = self.coupling_.T / np.sum( self.coupling_, 0)[:, None] transp_[~ np.isfinite(transp_)] = 0 - transp_Xt_ = np.dot(transp_, self.Xs) + transp_Xt_ = np.dot(transp_, self.xs_) # define the transported points - transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.Xt[idx, :] + transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.xt_[idx, :] transp_Xt.append(transp_Xt_) transp_Xt = np.concatenate(transp_Xt, axis=0) return transp_Xt - else: - print("POT-Warning") - print("Please provide Xt argument when calling inverse_transform") class SinkhornTransport(BaseTransport): @@ -1428,7 +1446,8 @@ class SinkhornLpl1Transport(BaseTransport): Returns self. """ - if Xs is not None and Xt is not None and ys is not None: + # check the necessary inputs parameters are here + if check_params(Xs=Xs, Xt=Xt, ys=ys): super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt) @@ -1438,10 +1457,7 @@ class SinkhornLpl1Transport(BaseTransport): numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol, verbose=self.verbose) - return self - else: - print("POT-Warning") - print("Please provide both Xs, Xt, ys arguments to fit method") + return self class SinkhornL1l2Transport(BaseTransport): @@ -1537,7 +1553,8 @@ class SinkhornL1l2Transport(BaseTransport): Returns self. """ - if Xs is not None and Xt is not None and ys is not None: + # check the necessary inputs parameters are here + if check_params(Xs=Xs, Xt=Xt, ys=ys): super(SinkhornL1l2Transport, self).fit(Xs, ys, Xt, yt) @@ -1554,10 +1571,7 @@ class SinkhornL1l2Transport(BaseTransport): self.coupling_ = returned_ self.log_ = dict() - return self - else: - print("POT-Warning") - print("Please, provide both Xs, Xt and ys argument to fit method") + return self class MappingTransport(BaseEstimator): @@ -1652,29 +1666,35 @@ class MappingTransport(BaseEstimator): Returns self """ - self.Xs = Xs - self.Xt = Xt - - if self.kernel == "linear": - returned_ = joint_OT_mapping_linear( - Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias, - verbose=self.verbose, verbose2=self.verbose2, - numItermax=self.max_iter, numInnerItermax=self.max_inner_iter, - stopThr=self.tol, stopInnerThr=self.inner_tol, log=self.log) + # check the necessary inputs parameters are here + if check_params(Xs=Xs, Xt=Xt): + + self.xs_ = Xs + self.xt_ = Xt + + if self.kernel == "linear": + returned_ = joint_OT_mapping_linear( + Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias, + verbose=self.verbose, verbose2=self.verbose2, + numItermax=self.max_iter, + numInnerItermax=self.max_inner_iter, stopThr=self.tol, + stopInnerThr=self.inner_tol, log=self.log) + + elif self.kernel == "gaussian": + returned_ = joint_OT_mapping_kernel( + Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias, + sigma=self.sigma, verbose=self.verbose, + verbose2=self.verbose, numItermax=self.max_iter, + numInnerItermax=self.max_inner_iter, + stopInnerThr=self.inner_tol, stopThr=self.tol, + log=self.log) - elif self.kernel == "gaussian": - returned_ = joint_OT_mapping_kernel( - Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias, - sigma=self.sigma, verbose=self.verbose, verbose2=self.verbose, - numItermax=self.max_iter, numInnerItermax=self.max_inner_iter, - stopInnerThr=self.inner_tol, stopThr=self.tol, log=self.log) - - # deal with the value of log - if self.log: - self.coupling_, self.mapping_, self.log_ = returned_ - else: - self.coupling_, self.mapping_ = returned_ - self.log_ = dict() + # deal with the value of log + if self.log: + self.coupling_, self.mapping_, self.log_ = returned_ + else: + self.coupling_, self.mapping_ = returned_ + self.log_ = dict() return self @@ -1692,22 +1712,26 @@ class MappingTransport(BaseEstimator): The transport source samples. """ - if np.array_equal(self.Xs, Xs): - # perform standard barycentric mapping - transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + # check the necessary inputs parameters are here + if check_params(Xs=Xs): - # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + if np.array_equal(self.xs_, Xs): + # perform standard barycentric mapping + transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] - # compute transported samples - transp_Xs = np.dot(transp, self.Xt) - else: - if self.kernel == "gaussian": - K = kernel(Xs, self.Xs, method=self.kernel, sigma=self.sigma) - elif self.kernel == "linear": - K = Xs - if self.bias: - K = np.hstack((K, np.ones((Xs.shape[0], 1)))) - transp_Xs = K.dot(self.mapping_) + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 - return transp_Xs + # compute transported samples + transp_Xs = np.dot(transp, self.xt_) + else: + if self.kernel == "gaussian": + K = kernel(Xs, self.xs_, method=self.kernel, + sigma=self.sigma) + elif self.kernel == "linear": + K = Xs + if self.bias: + K = np.hstack((K, np.ones((Xs.shape[0], 1)))) + transp_Xs = K.dot(self.mapping_) + + return transp_Xs diff --git a/ot/utils.py b/ot/utils.py index 29ad536..01f2a67 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -168,6 +168,27 @@ def parmap(f, X, nprocs=multiprocessing.cpu_count()): return [x for i, x in sorted(res)] +def check_params(**kwargs): + """check_params: check whether some parameters are missing + """ + + missing_params = [] + check = True + + for param in kwargs: + if kwargs[param] is None: + missing_params.append(param) + + if len(missing_params) > 0: + print("POT - Warning: following necessary parameters are missing") + for p in missing_params: + print("\n", p) + + check = False + + return check + + class deprecated(object): """Decorator to mark a function or class as deprecated. |