diff options
author | Slasnista <stan.chambon@gmail.com> | 2017-07-28 14:52:36 +0200 |
---|---|---|
committer | Nicolas Courty <Nico@MacBook-Pro-de-Nicolas.local> | 2017-09-01 11:09:13 +0200 |
commit | fa36e775ff2c1fe17cf1323d430a196774a6d2a5 (patch) | |
tree | dbda4e899026ab6175c5f847f5bf59a980a15b2e /ot/da.py | |
parent | f469205cf19915a34366c9410ccf6c8e13038673 (diff) |
Small modifications following NG's proposals
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 43 |
1 file changed, 31 insertions, 12 deletions
@@ -926,8 +926,8 @@ from sklearn.metrics import pairwise_distances """ - all methods have the same input parameters: Xs, Xt, ys, yt (what order ?) -- ref: is the entropic reg parameter -- eta: is the second reg parameter +- reg_e: is the entropic reg parameter +- reg_cl: is the second reg parameter - gamma_: is the optimal coupling - mapping barycentric for the moment @@ -940,7 +940,7 @@ Questions: class BaseTransport(BaseEstimator): - def fit(self, Xs=None, ys=None, Xt=None, yt=None, method="sinkhorn"): + def fit(self, Xs=None, ys=None, Xt=None, yt=None, method=None): """fit: estimates the optimal coupling Parameters: @@ -964,13 +964,17 @@ class BaseTransport(BaseEstimator): print("TODO: modify cost matrix accordingly") pass - # distribution estimation: should we change it ? - mu_s = np.ones(Xs.shape[0]) / float(Xs.shape[0]) - mu_t = np.ones(Xt.shape[0]) / float(Xt.shape[0]) + # distribution estimation + if self.distribution == "uniform": + mu_s = np.ones(Xs.shape[0]) / float(Xs.shape[0]) + mu_t = np.ones(Xt.shape[0]) / float(Xt.shape[0]) + else: + print("TODO: implement kernelized approach") + # coupling estimation if method == "sinkhorn": self.gamma_ = sinkhorn( - a=mu_s, b=mu_t, M=Cost, reg=self.reg, + a=mu_s, b=mu_t, M=Cost, reg=self.reg_e, numItermax=self.max_iter, stopThr=self.tol, verbose=self.verbose, log=self.log) else: @@ -1058,7 +1062,7 @@ class SinkhornTransport(BaseTransport): Parameters ---------- - - reg : parameter for entropic regularization + - reg_e : parameter for entropic regularization - mode: unsupervised (default) or semi supervised: controls whether labels are taken into accout to construct the optimal coupling - max_iter : maximum number of iterations @@ -1071,10 +1075,10 @@ class SinkhornTransport(BaseTransport): - gamma_: optimal coupling estimated by the fit function """ - def __init__(self, reg=1., mode="unsupervised", max_iter=1000, + def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000, tol=10e-9, verbose=False, 
log=False, mapping="barycentric", - metric="sqeuclidean"): - self.reg = reg + metric="sqeuclidean", distribution="uniform"): + self.reg_e = reg_e self.mode = mode self.max_iter = max_iter self.tol = tol @@ -1082,11 +1086,26 @@ class SinkhornTransport(BaseTransport): self.log = log self.mapping = mapping self.metric = metric + self.distribution = distribution self.method = "sinkhorn" def fit(self, Xs=None, ys=None, Xt=None, yt=None): - """_fit + """fit + + Parameters: + ----------- + - Xs: source samples, (ns samples, d features) numpy-like array + - ys: source labels + - Xt: target samples (nt samples, d features) numpy-like array + - yt: target labels + - method: algorithm to use to compute optimal coupling + (default: sinkhorn) + + Returns: + -------- + - self """ + return super(SinkhornTransport, self).fit( Xs, ys, Xt, yt, method=self.method) |