summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-07-28 08:49:10 +0200
committerSlasnista <stan.chambon@gmail.com>2017-07-28 08:49:10 +0200
commitca9c9d6d8ecef6a38e0fd6240538a8af35ad06f5 (patch)
tree2f135d251a4002cdfb2cbaede59079d82794ca52 /ot/da.py
parent553a45678c829896cbb076b8a89934525431c62c (diff)
first proposal for OT wrappers
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py179
1 files changed, 179 insertions, 0 deletions
diff --git a/ot/da.py b/ot/da.py
index 1dd4011..f534bf5 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -916,3 +916,182 @@ class OTDA_mapping_kernel(OTDA_mapping_linear):
else:
print("Warning, model not fitted yet, returning None")
return None
+
+##############################################################################
+# proposal
+##############################################################################
+
+from sklearn.base import BaseEstimator
+from sklearn.metrics import pairwise_distances
+
+"""
+- all methods have the same input parameters: Xs, Xt, ys, yt (what order ?)
+- ref: is the entropic reg parameter
+- eta: is the second reg parameter
+- gamma_: is the optimal coupling
+- mapping barycentric for the moment
+
+Questions:
+- Cost matrix estimation: from sklearn or from internal function ?
+- distribution estimation ? Look at Nathalie's approach
+- should everything been done into the fit from BaseTransport ?
+"""
+
+
+class BaseTransport(BaseEstimator):
+
+ def fit(self, Xs=None, ys=None, Xt=None, yt=None, method="sinkhorn"):
+ """fit: estimates the optimal coupling
+
+ Parameters:
+ -----------
+ - Xs: source samples, (ns samples, d features) numpy-like array
+ - ys: source labels
+ - Xt: target samples (nt samples, d features) numpy-like array
+ - yt: target labels
+ - method: algorithm to use to compute optimal coupling
+ (default: sinkhorn)
+
+ Returns:
+ --------
+ - self
+ """
+
+ # pairwise distance
+ Cost = pairwise_distances(Xs, Xt, metric=self.metric)
+
+ if self.mode == "semisupervised":
+ print("TODO: modify cost matrix accordingly")
+ pass
+
+ # distribution estimation: should we change it ?
+ mu_s = np.ones(Xs.shape[0]) / float(Xs.shape[0])
+ mu_t = np.ones(Xt.shape[0]) / float(Xt.shape[0])
+
+ if method == "sinkhorn":
+ self.gamma_ = sinkhorn(
+ a=mu_s, b=mu_t, M=Cost, reg=self.reg,
+ numItermax=self.max_iter, stopThr=self.tol,
+ verbose=self.verbose, log=self.log)
+ else:
+ print("TODO: implement the other methods")
+
+ return self
+
+ def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
+ """fit_transform
+
+ Parameters:
+ -----------
+ - Xs: source samples, (ns samples, d features) numpy-like array
+ - ys: source labels
+ - Xt: target samples (nt samples, d features) numpy-like array
+ - yt: target labels
+
+ Returns:
+ --------
+ - transp_Xt
+ """
+
+ return self.fit(Xs, ys, Xt, yt, self.method).transform(Xs, ys, Xt, yt)
+
+ def transform(self, Xs=None, ys=None, Xt=None, yt=None):
+ """transform: as a convention transports source samples
+ onto target samples
+
+ Parameters:
+ -----------
+ - Xs: source samples, (ns samples, d features) numpy-like array
+ - ys: source labels
+ - Xt: target samples (nt samples, d features) numpy-like array
+ - yt: target labels
+
+ Returns:
+ --------
+ - transp_Xt
+ """
+
+ if self.mapping == "barycentric":
+ transp = self.gamma_ / np.sum(self.gamma_, 1)[:, None]
+
+ # set nans to 0
+ transp[~ np.isfinite(transp)] = 0
+
+ # compute transported samples
+ transp_Xs = np.dot(transp, Xt)
+
+ return transp_Xs
+
+ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
+ """inverse_transform: as a convention transports target samples
+ onto source samples
+
+ Parameters:
+ -----------
+ - Xs: source samples, (ns samples, d features) numpy-like array
+ - ys: source labels
+ - Xt: target samples (nt samples, d features) numpy-like array
+ - yt: target labels
+
+ Returns:
+ --------
+ - transp_Xt
+ """
+
+ if self.mapping == "barycentric":
+ transp_ = self.gamma_.T / np.sum(self.gamma_, 0)[:, None]
+
+ # set nans to 0
+ transp_[~ np.isfinite(transp_)] = 0
+
+ # compute transported samples
+ transp_Xt = np.dot(transp_, Xs)
+ else:
+ print("mapping not yet implemented")
+
+ return transp_Xt
+
+
+class SinkhornTransport(BaseTransport):
+ """SinkhornTransport: class wrapper for optimal transport based on
+ Sinkhorn's algorithm
+
+ Parameters
+ ----------
+ - reg : parameter for entropic regularization
+ - mode: unsupervised (default) or semi supervised: controls whether
+ labels are taken into accout to construct the optimal coupling
+ - max_iter : maximum number of iterations
+ - tol : precision
+ - verbose : control verbosity
+ - log : control log
+
+ Attributes
+ ----------
+ - gamma_: optimal coupling estimated by the fit function
+ """
+
+ def __init__(self, reg=1., mode="unsupervised", max_iter=1000,
+ tol=10e-9, verbose=False, log=False, mapping="barycentric",
+ metric="sqeuclidean"):
+ self.reg = reg
+ self.mode = mode
+ self.max_iter = max_iter
+ self.tol = tol
+ self.verbose = verbose
+ self.log = log
+ self.mapping = mapping
+ self.metric = metric
+ self.method = "sinkhorn"
+
+ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
+ """_fit
+ """
+ return super(SinkhornTransport, self).fit(
+ Xs, ys, Xt, yt, method=self.method)
+
+
+if __name__ == "__main__":
+ print("Small test")
+
+ st = SinkhornTransport()