summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-08-04 14:55:54 +0200
committerSlasnista <stan.chambon@gmail.com>2017-08-04 14:55:54 +0200
commit738bfb1c560ff4e349f5083fc1f81a54e4be4980 (patch)
treead079d5f2f94985f504140973d4d2418c21245d4 /ot/da.py
parent778f4f76d7f162e7630c9ba5369a0e389e18433c (diff)
out of samples by Ferradans supported for transform and inverse_transform
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py29
1 files changed, 23 insertions, 6 deletions
diff --git a/ot/da.py b/ot/da.py
index 92a8f12..87d056d 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -1167,9 +1167,18 @@ class BaseTransport(BaseEstimator):
transp_Xs = np.dot(transp, self.Xt)
else:
# perform out of sample mapping
- print("Warning: out of sample mapping not yet implemented")
- print("input data will be returned")
- transp_Xs = Xs
+
+ # get the nearest neighbor in the source domain
+ D0 = dist(Xs, self.Xs)
+ idx = np.argmin(D0, axis=1)
+
+ # transport the source samples
+ transp = self.Coupling_ / np.sum(self.Coupling_, 1)[:, None]
+ transp[~ np.isfinite(transp)] = 0
+ transp_Xs_ = np.dot(transp, self.Xt)
+
+ # define the transported points
+ transp_Xs = transp_Xs_[idx, :] + Xs - self.Xs[idx, :]
return transp_Xs
@@ -1202,9 +1211,17 @@ class BaseTransport(BaseEstimator):
transp_Xt = np.dot(transp_, self.Xs)
else:
# perform out of sample mapping
- print("Warning: out of sample mapping not yet implemented")
- print("input data will be returned")
- transp_Xt = Xt
+
+ D0 = dist(Xt, self.Xt)
+ idx = np.argmin(D0, axis=1)
+
+ # transport the target samples
+ transp_ = self.Coupling_.T / np.sum(self.Coupling_, 0)[:, None]
+ transp_[~ np.isfinite(transp_)] = 0
+ transp_Xt_ = np.dot(transp_, self.Xs)
+
+ # define the transported points
+ transp_Xt = transp_Xt_[idx, :] + Xt - self.Xt[idx, :]
return transp_Xt