diff options
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 171 |
1 files changed, 167 insertions, 4 deletions
@@ -943,6 +943,46 @@ class BaseTransport(BaseEstimator): return transp_Xs + def transform_labels(self, ys=None): + """Propagate source labels ys to obtain estimated target labels + + Parameters + ---------- + ys : array-like, shape (n_source_samples,) + The class labels + + Returns + ------- + transp_ys : array-like, shape (n_target_samples,) + Estimated target labels. + """ + + # check the necessary inputs parameters are here + if check_params(ys=ys): + + classes = np.unique(ys) + n = len(classes) + D1 = np.zeros((n, len(ys))) + + # perform label propagation + transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + if np.min(classes) != 0: + ys = ys - np.min(classes) + classes = np.unique(ys) + + for c in classes: + D1[int(c), ys == c] = 1 + + # compute transported samples + transp_ys = np.dot(D1, transp) + + return np.argmax(transp_ys,axis=0) + + def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): """Transports target samples Xt onto target samples Xs @@ -1010,6 +1050,44 @@ class BaseTransport(BaseEstimator): return transp_Xt + def inverse_transform_labels(self, yt=None): + """Propagate target labels yt to obtain estimated source labels ys + + Parameters + ---------- + yt : array-like, shape (n_target_samples,) + + Returns + ------- + transp_ys : array-like, shape (n_source_samples,) + Estimated source labels. + """ + + # check the necessary inputs parameters are here + if check_params(yt=yt): + + classes = np.unique(yt) + n = len(classes) + D1 = np.zeros((n, len(yt))) + + # perform label propagation + transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + if np.min(classes) != 0: + yt = yt - np.min(classes) + classes = np.unique(yt) + + for c in classes: + D1[int(c), yt == c] = 1 + + # compute transported samples + transp_ys = np.dot(D1, transp.T) + + return np.argmax(transp_ys,axis=0) + class LinearTransport(BaseTransport): @@ -2017,10 +2095,10 @@ class JCPOTTransport(BaseTransport): Parameters ---------- - Xs : array-like, shape (n_source_samples, n_features) - The training input samples. - ys : array-like, shape (n_source_samples,) - The class labels + Xs : list of K array-like objects, shape K x (nk_source_samples, n_features) + A list of the training input samples. + ys : list of K array-like objects, shape K x (nk_source_samples,) + A list of the class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) @@ -2083,3 +2161,88 @@ class JCPOTTransport(BaseTransport): transp_Xs = np.concatenate(transp_Xs, axis=0) return transp_Xs + + def transform_labels(self, ys=None): + """Propagate source labels ys to obtain target labels + + Parameters + ---------- + ys : list of K array-like objects, shape K x (nk_source_samples,) + A list of the class labels + + Returns + ------- + yt : array-like, shape (n_target_samples,) + Estimated target labels. + """ + + # check the necessary inputs parameters are here + if check_params(ys=ys): + yt = np.zeros((len(np.unique(np.concatenate(ys))),self.xt_.shape[0])) + for i in range(len(ys)): + classes = np.unique(ys[i]) + n = len(classes) + ns = len(ys[i]) + + # perform label propagation + transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + if self.log: + D1 = self.log_['D1'][i] + else: + D1 = np.zeros((n, ns)) + + if np.min(classes) != 0: + ys = ys - np.min(classes) + classes = np.unique(ys) + + for c in classes: + D1[int(c), ys == c] = 1 + # compute transported samples + yt = yt + np.dot(D1, transp)/len(ys) + + return np.argmax(yt,axis=0) + + def inverse_transform_labels(self, yt=None): + """Propagate source labels ys to obtain target labels + + Parameters + ---------- + yt : array-like, shape (n_source_samples,) + The target class labels + + Returns + ------- + transp_ys : list of K array-like objects, shape K x (nk_source_samples,) + A list of estimated source labels + """ + + # check the necessary inputs parameters are here + if check_params(yt=yt): + transp_ys = [] + classes = np.unique(yt) + n = len(classes) + D1 = np.zeros((n, len(yt))) + + if np.min(classes) != 0: + yt = yt - np.min(classes) + classes = np.unique(yt) + + for c in classes: + D1[int(c), yt == c] = 1 + + for i in range(len(self.xs_)): + + # perform label propagation + transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + # compute transported labels + transp_ys.append(np.argmax(np.dot(D1, transp.T),axis=0)) + + return transp_ys |