diff options
author | Slasnista <stan.chambon@gmail.com> | 2017-08-04 14:55:54 +0200 |
---|---|---|
committer | Nicolas Courty <Nico@MacBook-Pro-de-Nicolas.local> | 2017-09-01 11:09:13 +0200 |
commit | b8672f67639e9daa3f91e555581256f984115f56 (patch) | |
tree | ad079d5f2f94985f504140973d4d2418c21245d4 /ot/da.py | |
parent | 266abb6c9a0fa53e419d72b99d1906cdf78a8009 (diff) |
out of samples by Ferradans supported for transform and inverse_transform
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 29 |
1 files changed, 23 insertions, 6 deletions
@@ -1167,9 +1167,18 @@ class BaseTransport(BaseEstimator): transp_Xs = np.dot(transp, self.Xt) else: # perform out of sample mapping - print("Warning: out of sample mapping not yet implemented") - print("input data will be returned") - transp_Xs = Xs + + # get the nearest neighbor in the source domain + D0 = dist(Xs, self.Xs) + idx = np.argmin(D0, axis=1) + + # transport the source samples + transp = self.Coupling_ / np.sum(self.Coupling_, 1)[:, None] + transp[~ np.isfinite(transp)] = 0 + transp_Xs_ = np.dot(transp, self.Xt) + + # define the transported points + transp_Xs = transp_Xs_[idx, :] + Xs - self.Xs[idx, :] return transp_Xs @@ -1202,9 +1211,17 @@ class BaseTransport(BaseEstimator): transp_Xt = np.dot(transp_, self.Xs) else: # perform out of sample mapping - print("Warning: out of sample mapping not yet implemented") - print("input data will be returned") - transp_Xt = Xt + + D0 = dist(Xt, self.Xt) + idx = np.argmin(D0, axis=1) + + # transport the target samples + transp_ = self.Coupling_.T / np.sum(self.Coupling_, 0)[:, None] + transp_[~ np.isfinite(transp_)] = 0 + transp_Xt_ = np.dot(transp_, self.Xs) + + # define the transported points + transp_Xt = transp_Xt_[idx, :] + Xt - self.Xt[idx, :] return transp_Xt |