From ca9c9d6d8ecef6a38e0fd6240538a8af35ad06f5 Mon Sep 17 00:00:00 2001 From: Slasnista Date: Fri, 28 Jul 2017 08:49:10 +0200 Subject: first proposal for OT wrappers --- ot/da.py | 179 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) (limited to 'ot') diff --git a/ot/da.py b/ot/da.py index 1dd4011..f534bf5 100644 --- a/ot/da.py +++ b/ot/da.py @@ -916,3 +916,182 @@ class OTDA_mapping_kernel(OTDA_mapping_linear): else: print("Warning, model not fitted yet, returning None") return None + +############################################################################## +# proposal +############################################################################## + +from sklearn.base import BaseEstimator +from sklearn.metrics import pairwise_distances + +""" +- all methods have the same input parameters: Xs, Xt, ys, yt (what order ?) +- ref: is the entropic reg parameter +- eta: is the second reg parameter +- gamma_: is the optimal coupling +- mapping barycentric for the moment + +Questions: +- Cost matrix estimation: from sklearn or from internal function ? +- distribution estimation ? Look at Nathalie's approach +- should everything been done into the fit from BaseTransport ? +""" + + +class BaseTransport(BaseEstimator): + + def fit(self, Xs=None, ys=None, Xt=None, yt=None, method="sinkhorn"): + """fit: estimates the optimal coupling + + Parameters: + ----------- + - Xs: source samples, (ns samples, d features) numpy-like array + - ys: source labels + - Xt: target samples (nt samples, d features) numpy-like array + - yt: target labels + - method: algorithm to use to compute optimal coupling + (default: sinkhorn) + + Returns: + -------- + - self + """ + + # pairwise distance + Cost = pairwise_distances(Xs, Xt, metric=self.metric) + + if self.mode == "semisupervised": + print("TODO: modify cost matrix accordingly") + pass + + # distribution estimation: should we change it ? + mu_s = np.ones(Xs.shape[0]) / float(Xs.shape[0]) + mu_t = np.ones(Xt.shape[0]) / float(Xt.shape[0]) + + if method == "sinkhorn": + self.gamma_ = sinkhorn( + a=mu_s, b=mu_t, M=Cost, reg=self.reg, + numItermax=self.max_iter, stopThr=self.tol, + verbose=self.verbose, log=self.log) + else: + print("TODO: implement the other methods") + + return self + + def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None): + """fit_transform + + Parameters: + ----------- + - Xs: source samples, (ns samples, d features) numpy-like array + - ys: source labels + - Xt: target samples (nt samples, d features) numpy-like array + - yt: target labels + + Returns: + -------- + - transp_Xt + """ + + return self.fit(Xs, ys, Xt, yt, self.method).transform(Xs, ys, Xt, yt) + + def transform(self, Xs=None, ys=None, Xt=None, yt=None): + """transform: as a convention transports source samples + onto target samples + + Parameters: + ----------- + - Xs: source samples, (ns samples, d features) numpy-like array + - ys: source labels + - Xt: target samples (nt samples, d features) numpy-like array + - yt: target labels + + Returns: + -------- + - transp_Xt + """ + + if self.mapping == "barycentric": + transp = self.gamma_ / np.sum(self.gamma_, 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + # compute transported samples + transp_Xs = np.dot(transp, Xt) + + return transp_Xs + + def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None): + """inverse_transform: as a convention transports target samples + onto source samples + + Parameters: + ----------- + - Xs: source samples, (ns samples, d features) numpy-like array + - ys: source labels + - Xt: target samples (nt samples, d features) numpy-like array + - yt: target labels + + Returns: + -------- + - transp_Xt + """ + + if self.mapping == "barycentric": + transp_ = self.gamma_.T / np.sum(self.gamma_, 0)[:, None] + + # set nans to 0 + transp_[~ np.isfinite(transp_)] = 0 + + # compute transported samples + transp_Xt = np.dot(transp_, Xs) + else: + print("mapping not yet implemented") + + return transp_Xt + + +class SinkhornTransport(BaseTransport): + """SinkhornTransport: class wrapper for optimal transport based on + Sinkhorn's algorithm + + Parameters + ---------- + - reg : parameter for entropic regularization + - mode: unsupervised (default) or semi supervised: controls whether + labels are taken into accout to construct the optimal coupling + - max_iter : maximum number of iterations + - tol : precision + - verbose : control verbosity + - log : control log + + Attributes + ---------- + - gamma_: optimal coupling estimated by the fit function + """ + + def __init__(self, reg=1., mode="unsupervised", max_iter=1000, + tol=10e-9, verbose=False, log=False, mapping="barycentric", + metric="sqeuclidean"): + self.reg = reg + self.mode = mode + self.max_iter = max_iter + self.tol = tol + self.verbose = verbose + self.log = log + self.mapping = mapping + self.metric = metric + self.method = "sinkhorn" + + def fit(self, Xs=None, ys=None, Xt=None, yt=None): + """_fit + """ + return super(SinkhornTransport, self).fit( + Xs, ys, Xt, yt, method=self.method) + + +if __name__ == "__main__": + print("Small test") + + st = SinkhornTransport() -- cgit v1.2.3