From b8672f67639e9daa3f91e555581256f984115f56 Mon Sep 17 00:00:00 2001 From: Slasnista Date: Fri, 4 Aug 2017 14:55:54 +0200 Subject: out of samples by Ferradans supported for transform and inverse_transform --- ot/da.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) (limited to 'ot/da.py') diff --git a/ot/da.py b/ot/da.py index 92a8f12..87d056d 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1167,9 +1167,18 @@ class BaseTransport(BaseEstimator): transp_Xs = np.dot(transp, self.Xt) else: # perform out of sample mapping - print("Warning: out of sample mapping not yet implemented") - print("input data will be returned") - transp_Xs = Xs + + # get the nearest neighbor in the source domain + D0 = dist(Xs, self.Xs) + idx = np.argmin(D0, axis=1) + + # transport the source samples + transp = self.Coupling_ / np.sum(self.Coupling_, 1)[:, None] + transp[~ np.isfinite(transp)] = 0 + transp_Xs_ = np.dot(transp, self.Xt) + + # define the transported points + transp_Xs = transp_Xs_[idx, :] + Xs - self.Xs[idx, :] return transp_Xs @@ -1202,9 +1211,17 @@ class BaseTransport(BaseEstimator): transp_Xt = np.dot(transp_, self.Xs) else: # perform out of sample mapping - print("Warning: out of sample mapping not yet implemented") - print("input data will be returned") - transp_Xt = Xt + + D0 = dist(Xt, self.Xt) + idx = np.argmin(D0, axis=1) + + # transport the target samples + transp_ = self.Coupling_.T / np.sum(self.Coupling_, 0)[:, None] + transp_[~ np.isfinite(transp_)] = 0 + transp_Xt_ = np.dot(transp_, self.Xs) + + # define the transported points + transp_Xt = transp_Xt_[idx, :] + Xt - self.Xt[idx, :] return transp_Xt -- cgit v1.2.3