Diffstat:
 ot/da.py        | 171 ++++++++++++++++++++++++++++++++++++++++++++++++++++--
 test/test_da.py |  47 +++++++++++++
 2 files changed, 214 insertions(+), 4 deletions(-)
@@ -943,6 +943,46 @@ class BaseTransport(BaseEstimator): return transp_Xs + def transform_labels(self, ys=None): + """Propagate source labels ys to obtain estimated target labels + + Parameters + ---------- + ys : array-like, shape (n_source_samples,) + The class labels + + Returns + ------- + transp_ys : array-like, shape (n_target_samples,) + Estimated target labels. + """ + + # check the necessary inputs parameters are here + if check_params(ys=ys): + + classes = np.unique(ys) + n = len(classes) + D1 = np.zeros((n, len(ys))) + + # perform label propagation + transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + if np.min(classes) != 0: + ys = ys - np.min(classes) + classes = np.unique(ys) + + for c in classes: + D1[int(c), ys == c] = 1 + + # compute transported samples + transp_ys = np.dot(D1, transp) + + return np.argmax(transp_ys,axis=0) + + def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): """Transports target samples Xt onto target samples Xs @@ -1010,6 +1050,44 @@ class BaseTransport(BaseEstimator): return transp_Xt + def inverse_transform_labels(self, yt=None): + """Propagate target labels yt to obtain estimated source labels ys + + Parameters + ---------- + yt : array-like, shape (n_target_samples,) + + Returns + ------- + transp_ys : array-like, shape (n_source_samples,) + Estimated source labels. 
+ """ + + # check the necessary inputs parameters are here + if check_params(yt=yt): + + classes = np.unique(yt) + n = len(classes) + D1 = np.zeros((n, len(yt))) + + # perform label propagation + transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + if np.min(classes) != 0: + yt = yt - np.min(classes) + classes = np.unique(yt) + + for c in classes: + D1[int(c), yt == c] = 1 + + # compute transported samples + transp_ys = np.dot(D1, transp.T) + + return np.argmax(transp_ys,axis=0) + class LinearTransport(BaseTransport): @@ -2017,10 +2095,10 @@ class JCPOTTransport(BaseTransport): Parameters ---------- - Xs : array-like, shape (n_source_samples, n_features) - The training input samples. - ys : array-like, shape (n_source_samples,) - The class labels + Xs : list of K array-like objects, shape K x (nk_source_samples, n_features) + A list of the training input samples. + ys : list of K array-like objects, shape K x (nk_source_samples,) + A list of the class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) @@ -2083,3 +2161,88 @@ class JCPOTTransport(BaseTransport): transp_Xs = np.concatenate(transp_Xs, axis=0) return transp_Xs + + def transform_labels(self, ys=None): + """Propagate source labels ys to obtain target labels + + Parameters + ---------- + ys : list of K array-like objects, shape K x (nk_source_samples,) + A list of the class labels + + Returns + ------- + yt : array-like, shape (n_target_samples,) + Estimated target labels. 
+        """
+
+        # check the necessary inputs parameters are here
+        if check_params(ys=ys):
+            yt = np.zeros((len(np.unique(np.concatenate(ys))), self.xt_.shape[0]))
+            for i in range(len(ys)):
+                classes = np.unique(ys[i])
+                n = len(classes)
+                ns = len(ys[i])
+
+                # perform label propagation
+                transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None]
+
+                # set nans to 0
+                transp[~ np.isfinite(transp)] = 0
+
+                if self.log:
+                    D1 = self.log_['D1'][i]
+                else:
+                    D1 = np.zeros((n, ns))
+
+                # NOTE: operate on the current source domain's labels ys[i],
+                # not on the whole list ys
+                if np.min(classes) != 0:
+                    ys[i] = ys[i] - np.min(classes)
+                    classes = np.unique(ys[i])
+
+                for c in classes:
+                    D1[int(c), ys[i] == c] = 1
+
+                # compute transported samples
+                yt = yt + np.dot(D1, transp) / len(ys)
+
+            return np.argmax(yt, axis=0)
+
+    def inverse_transform_labels(self, yt=None):
+        """Propagate target labels yt to obtain estimated source labels ys
+
+        Parameters
+        ----------
+        yt : array-like, shape (n_target_samples,)
+            The target class labels
+
+        Returns
+        -------
+        transp_ys : list of K array-like objects, shape K x (nk_source_samples,)
+            A list of estimated source labels
+        """
+
+        # check the necessary inputs parameters are here
+        if check_params(yt=yt):
+            transp_ys = []
+            classes = np.unique(yt)
+            n = len(classes)
+            D1 = np.zeros((n, len(yt)))
+
+            if np.min(classes) != 0:
+                yt = yt - np.min(classes)
+                classes = np.unique(yt)
+
+            for c in classes:
+                D1[int(c), yt == c] = 1
+
+            for i in range(len(self.xs_)):
+
+                # perform label propagation
+                transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None]
+
+                # set nans to 0
+                transp[~ np.isfinite(transp)] = 0
+
+                # compute transported labels
+                transp_ys.append(np.argmax(np.dot(D1, transp.T), axis=0))
+
+            return transp_ys
diff --git a/test/test_da.py b/test/test_da.py
index c54dab7..4eb6de0 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -65,6 +65,14 @@ def test_sinkhorn_lpl1_transport_class():
     transp_Xs = otda.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
     assert_equal(transp_Xs.shape, Xs.shape)
 
+    # check label propagation
+    transp_yt = 
otda.transform_labels(ys) + assert_equal(transp_yt.shape[0], yt.shape[0]) + + # check inverse label propagation + transp_ys = otda.inverse_transform_labels(yt) + assert_equal(transp_ys.shape[0], ys.shape[0]) + # test unsupervised vs semi-supervised mode otda_unsup = ot.da.SinkhornLpl1Transport() otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt) @@ -129,6 +137,14 @@ def test_sinkhorn_l1l2_transport_class(): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) + # check label propagation + transp_yt = otda.transform_labels(ys) + assert_equal(transp_yt.shape[0], yt.shape[0]) + + # check inverse label propagation + transp_ys = otda.inverse_transform_labels(yt) + assert_equal(transp_ys.shape[0], ys.shape[0]) + Xt_new, _ = make_data_classif('3gauss2', nt + 1) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) @@ -210,6 +226,14 @@ def test_sinkhorn_transport_class(): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) + # check label propagation + transp_yt = otda.transform_labels(ys) + assert_equal(transp_yt.shape[0], yt.shape[0]) + + # check inverse label propagation + transp_ys = otda.inverse_transform_labels(yt) + assert_equal(transp_ys.shape[0], ys.shape[0]) + Xt_new, _ = make_data_classif('3gauss2', nt + 1) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) @@ -271,6 +295,14 @@ def test_unbalanced_sinkhorn_transport_class(): transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) + # check label propagation + transp_yt = otda.transform_labels(ys) + assert_equal(transp_yt.shape[0], yt.shape[0]) + + # check inverse label propagation + transp_ys = otda.inverse_transform_labels(yt) + assert_equal(transp_ys.shape[0], ys.shape[0]) + Xs_new, _ = make_data_classif('3gauss', ns + 1) transp_Xs_new = otda.transform(Xs_new) @@ -353,6 +385,14 @@ def test_emd_transport_class(): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) + # check label propagation + transp_yt = 
otda.transform_labels(ys) + assert_equal(transp_yt.shape[0], yt.shape[0]) + + # check inverse label propagation + transp_ys = otda.inverse_transform_labels(yt) + assert_equal(transp_ys.shape[0], ys.shape[0]) + Xt_new, _ = make_data_classif('3gauss2', nt + 1) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) @@ -602,6 +642,13 @@ def test_jcpot_transport_class(): # check that the oos method is working assert_equal(transp_Xs_new.shape, Xs_new.shape) + # check label propagation + transp_yt = otda.transform_labels(ys) + assert_equal(transp_yt.shape[0], yt.shape[0]) + + # check inverse label propagation + transp_ys = otda.inverse_transform_labels(yt) + [assert_equal(x.shape, y.shape) for x, y in zip(transp_ys, ys)] def test_jcpot_barycenter(): """test_jcpot_barycenter |