summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/da.py29
-rw-r--r--test/test_da.py32
2 files changed, 39 insertions, 22 deletions
diff --git a/ot/da.py b/ot/da.py
index 92a8f12..87d056d 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -1167,9 +1167,18 @@ class BaseTransport(BaseEstimator):
transp_Xs = np.dot(transp, self.Xt)
else:
# perform out of sample mapping
- print("Warning: out of sample mapping not yet implemented")
- print("input data will be returned")
- transp_Xs = Xs
+
+ # get the nearest neighbor in the source domain
+ D0 = dist(Xs, self.Xs)
+ idx = np.argmin(D0, axis=1)
+
+ # transport the source samples
+ transp = self.Coupling_ / np.sum(self.Coupling_, 1)[:, None]
+ transp[~ np.isfinite(transp)] = 0
+ transp_Xs_ = np.dot(transp, self.Xt)
+
+ # define the transported points
+ transp_Xs = transp_Xs_[idx, :] + Xs - self.Xs[idx, :]
return transp_Xs
@@ -1202,9 +1211,17 @@ class BaseTransport(BaseEstimator):
transp_Xt = np.dot(transp_, self.Xs)
else:
# perform out of sample mapping
- print("Warning: out of sample mapping not yet implemented")
- print("input data will be returned")
- transp_Xt = Xt
+
+ D0 = dist(Xt, self.Xt)
+ idx = np.argmin(D0, axis=1)
+
+ # transport the target samples
+ transp_ = self.Coupling_.T / np.sum(self.Coupling_, 0)[:, None]
+ transp_[~ np.isfinite(transp_)] = 0
+ transp_Xt_ = np.dot(transp_, self.Xs)
+
+ # define the transported points
+ transp_Xt = transp_Xt_[idx, :] + Xt - self.Xt[idx, :]
return transp_Xt
diff --git a/test/test_da.py b/test/test_da.py
index ecd2a3a..aed9f61 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -45,8 +45,8 @@ def test_sinkhorn_lpl1_transport_class():
Xs_new, _ = get_data_classif('3gauss', ns + 1)
transp_Xs_new = clf.transform(Xs_new)
- # check that the oos method is not working
- assert_equal(transp_Xs_new, Xs_new)
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
# test inverse transform
transp_Xt = clf.inverse_transform(Xt=Xt)
@@ -55,8 +55,8 @@ def test_sinkhorn_lpl1_transport_class():
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
- # check that the oos method is not working and returns the input data
- assert_equal(transp_Xt_new, Xt_new)
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
# test fit_transform
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
@@ -108,8 +108,8 @@ def test_sinkhorn_l1l2_transport_class():
Xs_new, _ = get_data_classif('3gauss', ns + 1)
transp_Xs_new = clf.transform(Xs_new)
- # check that the oos method is not working
- assert_equal(transp_Xs_new, Xs_new)
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
# test inverse transform
transp_Xt = clf.inverse_transform(Xt=Xt)
@@ -118,8 +118,8 @@ def test_sinkhorn_l1l2_transport_class():
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
- # check that the oos method is not working and returns the input data
- assert_equal(transp_Xt_new, Xt_new)
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
# test fit_transform
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
@@ -171,8 +171,8 @@ def test_sinkhorn_transport_class():
Xs_new, _ = get_data_classif('3gauss', ns + 1)
transp_Xs_new = clf.transform(Xs_new)
- # check that the oos method is not working
- assert_equal(transp_Xs_new, Xs_new)
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
# test inverse transform
transp_Xt = clf.inverse_transform(Xt=Xt)
@@ -181,8 +181,8 @@ def test_sinkhorn_transport_class():
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
- # check that the oos method is not working and returns the input data
- assert_equal(transp_Xt_new, Xt_new)
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
# test fit_transform
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)
@@ -234,8 +234,8 @@ def test_emd_transport_class():
Xs_new, _ = get_data_classif('3gauss', ns + 1)
transp_Xs_new = clf.transform(Xs_new)
- # check that the oos method is not working
- assert_equal(transp_Xs_new, Xs_new)
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
# test inverse transform
transp_Xt = clf.inverse_transform(Xt=Xt)
@@ -244,8 +244,8 @@ def test_emd_transport_class():
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
- # check that the oos method is not working and returns the input data
- assert_equal(transp_Xt_new, Xt_new)
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
# test fit_transform
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)