diff options
author | Slasnista <stan.chambon@gmail.com> | 2017-08-04 14:55:54 +0200 |
---|---|---|
committer | Slasnista <stan.chambon@gmail.com> | 2017-08-04 14:55:54 +0200 |
commit | 738bfb1c560ff4e349f5083fc1f81a54e4be4980 (patch) | |
tree | ad079d5f2f94985f504140973d4d2418c21245d4 | |
parent | 778f4f76d7f162e7630c9ba5369a0e389e18433c (diff) |
out of samples by Ferradans supported for transform and inverse_transform
-rw-r--r-- | ot/da.py | 29 | ||||
-rw-r--r-- | test/test_da.py | 32 |
2 files changed, 39 insertions, 22 deletions
@@ -1167,9 +1167,18 @@ class BaseTransport(BaseEstimator): transp_Xs = np.dot(transp, self.Xt) else: # perform out of sample mapping - print("Warning: out of sample mapping not yet implemented") - print("input data will be returned") - transp_Xs = Xs + + # get the nearest neighbor in the source domain + D0 = dist(Xs, self.Xs) + idx = np.argmin(D0, axis=1) + + # transport the source samples + transp = self.Coupling_ / np.sum(self.Coupling_, 1)[:, None] + transp[~ np.isfinite(transp)] = 0 + transp_Xs_ = np.dot(transp, self.Xt) + + # define the transported points + transp_Xs = transp_Xs_[idx, :] + Xs - self.Xs[idx, :] return transp_Xs @@ -1202,9 +1211,17 @@ class BaseTransport(BaseEstimator): transp_Xt = np.dot(transp_, self.Xs) else: # perform out of sample mapping - print("Warning: out of sample mapping not yet implemented") - print("input data will be returned") - transp_Xt = Xt + + D0 = dist(Xt, self.Xt) + idx = np.argmin(D0, axis=1) + + # transport the target samples + transp_ = self.Coupling_.T / np.sum(self.Coupling_, 0)[:, None] + transp_[~ np.isfinite(transp_)] = 0 + transp_Xt_ = np.dot(transp_, self.Xs) + + # define the transported points + transp_Xt = transp_Xt_[idx, :] + Xt - self.Xt[idx, :] return transp_Xt diff --git a/test/test_da.py b/test/test_da.py index ecd2a3a..aed9f61 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -45,8 +45,8 @@ def test_sinkhorn_lpl1_transport_class(): Xs_new, _ = get_data_classif('3gauss', ns + 1) transp_Xs_new = clf.transform(Xs_new) - # check that the oos method is not working - assert_equal(transp_Xs_new, Xs_new) + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) # test inverse transform transp_Xt = clf.inverse_transform(Xt=Xt) @@ -55,8 +55,8 @@ def test_sinkhorn_lpl1_transport_class(): Xt_new, _ = get_data_classif('3gauss2', nt + 1) transp_Xt_new = clf.inverse_transform(Xt=Xt_new) - # check that the oos method is not working and returns the input data - assert_equal(transp_Xt_new, Xt_new) + # check that the oos method is working + assert_equal(transp_Xt_new.shape, Xt_new.shape) # test fit_transform transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt) @@ -108,8 +108,8 @@ def test_sinkhorn_l1l2_transport_class(): Xs_new, _ = get_data_classif('3gauss', ns + 1) transp_Xs_new = clf.transform(Xs_new) - # check that the oos method is not working - assert_equal(transp_Xs_new, Xs_new) + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) # test inverse transform transp_Xt = clf.inverse_transform(Xt=Xt) @@ -118,8 +118,8 @@ def test_sinkhorn_l1l2_transport_class(): Xt_new, _ = get_data_classif('3gauss2', nt + 1) transp_Xt_new = clf.inverse_transform(Xt=Xt_new) - # check that the oos method is not working and returns the input data - assert_equal(transp_Xt_new, Xt_new) + # check that the oos method is working + assert_equal(transp_Xt_new.shape, Xt_new.shape) # test fit_transform transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt) @@ -171,8 +171,8 @@ def test_sinkhorn_transport_class(): Xs_new, _ = get_data_classif('3gauss', ns + 1) transp_Xs_new = clf.transform(Xs_new) - # check that the oos method is not working - assert_equal(transp_Xs_new, Xs_new) + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) # test inverse transform transp_Xt = clf.inverse_transform(Xt=Xt) @@ -181,8 +181,8 @@ def test_sinkhorn_transport_class(): Xt_new, _ = get_data_classif('3gauss2', nt + 1) transp_Xt_new = clf.inverse_transform(Xt=Xt_new) - # check that the oos method is not working and returns the input data - assert_equal(transp_Xt_new, Xt_new) + # check that the oos method is working + assert_equal(transp_Xt_new.shape, Xt_new.shape) # test fit_transform transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt) @@ -234,8 +234,8 @@ def test_emd_transport_class(): Xs_new, _ = get_data_classif('3gauss', ns + 1) transp_Xs_new = clf.transform(Xs_new) - # check that the oos method is not working - assert_equal(transp_Xs_new, Xs_new) + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) # test inverse transform transp_Xt = clf.inverse_transform(Xt=Xt) @@ -244,8 +244,8 @@ def test_emd_transport_class(): Xt_new, _ = get_data_classif('3gauss2', nt + 1) transp_Xt_new = clf.inverse_transform(Xt=Xt_new) - # check that the oos method is not working and returns the input data - assert_equal(transp_Xt_new, Xt_new) + # check that the oos method is working + assert_equal(transp_Xt_new.shape, Xt_new.shape) # test fit_transform transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt) |