diff options
Diffstat (limited to 'ot')
-rw-r--r-- | ot/da.py | 89 |
1 files changed, 58 insertions, 31 deletions
@@ -1147,7 +1147,7 @@ class BaseTransport(BaseEstimator): return self.fit(Xs, ys, Xt, yt).transform(Xs, ys, Xt, yt) - def transform(self, Xs=None, ys=None, Xt=None, yt=None): + def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): """Transports source samples Xs onto target ones Xt Parameters @@ -1160,6 +1160,8 @@ class BaseTransport(BaseEstimator): The training input samples. yt : array-like, shape (n_labeled_target_samples,) The class labels + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform Returns ------- @@ -1178,34 +1180,48 @@ class BaseTransport(BaseEstimator): transp_Xs = np.dot(transp, self.Xt) else: # perform out of sample mapping + indices = np.arange(Xs.shape[0]) + batch_ind = [ + indices[i:i + batch_size] + for i in range(0, len(indices), batch_size)] - # get the nearest neighbor in the source domain - D0 = dist(Xs, self.Xs) - idx = np.argmin(D0, axis=1) + transp_Xs = [] + for bi in batch_ind: - # transport the source samples - transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] - transp[~ np.isfinite(transp)] = 0 - transp_Xs_ = np.dot(transp, self.Xt) + # get the nearest neighbor in the source domain + D0 = dist(Xs[bi], self.Xs) + idx = np.argmin(D0, axis=1) + + # transport the source samples + transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + transp[~ np.isfinite(transp)] = 0 + transp_Xs_ = np.dot(transp, self.Xt) - # define the transported points - transp_Xs = transp_Xs_[idx, :] + Xs - self.Xs[idx, :] + # define the transported points + transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.Xs[idx, :] + + transp_Xs.append(transp_Xs_) + + transp_Xs = np.concatenate(transp_Xs, axis=0) return transp_Xs - def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None): + def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, + batch_size=128): """Transports target samples Xt onto target samples Xs Parameters ---------- Xs : array-like, shape (n_source_samples, n_features) The training input samples. - ys : array-like, shape = (n_source_samples,) + ys : array-like, shape (n_source_samples,) The class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. - yt : array-like, shape = (n_labeled_target_samples,) + yt : array-like, shape (n_labeled_target_samples,) The class labels + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform Returns ------- @@ -1224,17 +1240,28 @@ class BaseTransport(BaseEstimator): transp_Xt = np.dot(transp_, self.Xs) else: # perform out of sample mapping + indices = np.arange(Xt.shape[0]) + batch_ind = [ + indices[i:i + batch_size] + for i in range(0, len(indices), batch_size)] - D0 = dist(Xt, self.Xt) - idx = np.argmin(D0, axis=1) + transp_Xt = [] + for bi in batch_ind: - # transport the target samples - transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None] - transp_[~ np.isfinite(transp_)] = 0 - transp_Xt_ = np.dot(transp_, self.Xs) + D0 = dist(Xt[bi], self.Xt) + idx = np.argmin(D0, axis=1) + + # transport the target samples + transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None] + transp_[~ np.isfinite(transp_)] = 0 + transp_Xt_ = np.dot(transp_, self.Xs) + + # define the transported points + transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.Xt[idx, :] - # define the transported points - transp_Xt = transp_Xt_[idx, :] + Xt - self.Xt[idx, :] + transp_Xt.append(transp_Xt_) + + transp_Xt = np.concatenate(transp_Xt, axis=0) return transp_Xt @@ -1306,11 +1333,11 @@ class SinkhornTransport(BaseTransport): ---------- Xs : array-like, shape (n_source_samples, n_features) The training input samples. - ys : array-like, shape = (n_source_samples,) + ys : array-like, shape (n_source_samples,) The class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. - yt : array-like, shape = (n_labeled_target_samples,) + yt : array-like, shape (n_labeled_target_samples,) The class labels Returns @@ -1381,11 +1408,11 @@ class EMDTransport(BaseTransport): ---------- Xs : array-like, shape (n_source_samples, n_features) The training input samples. - ys : array-like, shape = (n_source_samples,) + ys : array-like, shape (n_source_samples,) The class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. - yt : array-like, shape = (n_labeled_target_samples,) + yt : array-like, shape (n_labeled_target_samples,) The class labels Returns @@ -1480,11 +1507,11 @@ class SinkhornLpl1Transport(BaseTransport): ---------- Xs : array-like, shape (n_source_samples, n_features) The training input samples. - ys : array-like, shape = (n_source_samples,) + ys : array-like, shape (n_source_samples,) The class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. - yt : array-like, shape = (n_labeled_target_samples,) + yt : array-like, shape (n_labeled_target_samples,) The class labels Returns @@ -1581,11 +1608,11 @@ class SinkhornL1l2Transport(BaseTransport): ---------- Xs : array-like, shape (n_source_samples, n_features) The training input samples. - ys : array-like, shape = (n_source_samples,) + ys : array-like, shape (n_source_samples,) The class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. - yt : array-like, shape = (n_labeled_target_samples,) + yt : array-like, shape (n_labeled_target_samples,) The class labels Returns @@ -1675,11 +1702,11 @@ class MappingTransport(BaseEstimator): ---------- Xs : array-like, shape (n_source_samples, n_features) The training input samples. - ys : array-like, shape = (n_source_samples,) + ys : array-like, shape (n_source_samples,) The class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. - yt : array-like, shape = (n_labeled_target_samples,) + yt : array-like, shape (n_labeled_target_samples,) The class labels Returns |