diff options
-rw-r--r-- | ot/da.py | 43 |
1 files changed, 31 insertions, 12 deletions
@@ -926,8 +926,8 @@ from sklearn.metrics import pairwise_distances """ - all methods have the same input parameters: Xs, Xt, ys, yt (what order ?) -- ref: is the entropic reg parameter -- eta: is the second reg parameter +- reg_e: is the entropic reg parameter +- reg_cl: is the second reg parameter - gamma_: is the optimal coupling - mapping barycentric for the moment @@ -940,7 +940,7 @@ Questions: class BaseTransport(BaseEstimator): - def fit(self, Xs=None, ys=None, Xt=None, yt=None, method="sinkhorn"): + def fit(self, Xs=None, ys=None, Xt=None, yt=None, method=None): """fit: estimates the optimal coupling Parameters: @@ -964,13 +964,17 @@ class BaseTransport(BaseEstimator): print("TODO: modify cost matrix accordingly") pass - # distribution estimation: should we change it ? - mu_s = np.ones(Xs.shape[0]) / float(Xs.shape[0]) - mu_t = np.ones(Xt.shape[0]) / float(Xt.shape[0]) + # distribution estimation + if self.distribution == "uniform": + mu_s = np.ones(Xs.shape[0]) / float(Xs.shape[0]) + mu_t = np.ones(Xt.shape[0]) / float(Xt.shape[0]) + else: + print("TODO: implement kernelized approach") + # coupling estimation if method == "sinkhorn": self.gamma_ = sinkhorn( - a=mu_s, b=mu_t, M=Cost, reg=self.reg, + a=mu_s, b=mu_t, M=Cost, reg=self.reg_e, numItermax=self.max_iter, stopThr=self.tol, verbose=self.verbose, log=self.log) else: @@ -1058,7 +1062,7 @@ class SinkhornTransport(BaseTransport): Parameters ---------- - - reg : parameter for entropic regularization + - reg_e : parameter for entropic regularization - mode: unsupervised (default) or semi supervised: controls whether labels are taken into accout to construct the optimal coupling - max_iter : maximum number of iterations @@ -1071,10 +1075,10 @@ class SinkhornTransport(BaseTransport): - gamma_: optimal coupling estimated by the fit function """ - def __init__(self, reg=1., mode="unsupervised", max_iter=1000, + def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000, tol=10e-9, verbose=False, log=False, mapping="barycentric", - metric="sqeuclidean"): - self.reg = reg + metric="sqeuclidean", distribution="uniform"): + self.reg_e = reg_e self.mode = mode self.max_iter = max_iter self.tol = tol @@ -1082,11 +1086,26 @@ class SinkhornTransport(BaseTransport): self.log = log self.mapping = mapping self.metric = metric + self.distribution = distribution self.method = "sinkhorn" def fit(self, Xs=None, ys=None, Xt=None, yt=None): - """_fit + """fit + + Parameters: + ----------- + - Xs: source samples, (ns samples, d features) numpy-like array + - ys: source labels + - Xt: target samples (nt samples, d features) numpy-like array + - yt: target labels + - method: algorithm to use to compute optimal coupling + (default: sinkhorn) + + Returns: + -------- + - self """ + return super(SinkhornTransport, self).fit( Xs, ys, Xt, yt, method=self.method) |