-rw-r--r--  ot/da.py          171
-rw-r--r--  test/test_da.py    47
2 files changed, 214 insertions, 4 deletions
diff --git a/ot/da.py b/ot/da.py
index 3a458eb..29b0a8b 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -943,6 +943,46 @@ class BaseTransport(BaseEstimator):
return transp_Xs
+ def transform_labels(self, ys=None):
+ """Propagate source labels ys to obtain estimated target labels
+
+ Parameters
+ ----------
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+
+ Returns
+ -------
+ transp_ys : array-like, shape (n_target_samples,)
+ Estimated target labels.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(ys=ys):
+
+ classes = np.unique(ys)
+ n = len(classes)
+ D1 = np.zeros((n, len(ys)))
+
+ # perform label propagation
+ transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+
+ # set nans to 0
+ transp[~ np.isfinite(transp)] = 0
+
+ if np.min(classes) != 0:
+ ys = ys - np.min(classes)
+ classes = np.unique(ys)
+
+ for c in classes:
+ D1[int(c), ys == c] = 1
+
+ # compute propagated labels
+ transp_ys = np.dot(D1, transp)
+
+ return np.argmax(transp_ys, axis=0)
+
+
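To make the propagation step above concrete, here is a minimal self-contained sketch of the same computation with made-up numbers: D1 one-hot encodes the source labels per class, the coupling is row-normalised, and the argmax over the class scores gives one estimated label per target point.

import numpy as np

coupling = np.array([[0.4, 0.1],
                     [0.0, 0.5]])   # toy coupling_, shape (ns=2, nt=2)
ys = np.array([0, 1])               # toy source labels

D1 = np.zeros((2, 2))               # (n_classes, ns) one-hot label matrix
D1[ys, np.arange(len(ys))] = 1

transp = coupling / coupling.sum(1, keepdims=True)   # row-normalised plan
transp[~np.isfinite(transp)] = 0                     # guard against empty rows

scores = D1.dot(transp)             # (n_classes, nt) class scores per target sample
yt_est = scores.argmax(axis=0)      # -> array([0, 1])
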
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
"""Transports target samples Xt onto target samples Xs
@@ -1010,6 +1050,44 @@ class BaseTransport(BaseEstimator):
return transp_Xt
+ def inverse_transform_labels(self, yt=None):
+ """Propagate target labels yt to obtain estimated source labels ys
+
+ Parameters
+ ----------
+ yt : array-like, shape (n_target_samples,)
+
+ Returns
+ -------
+ transp_ys : array-like, shape (n_source_samples,)
+ Estimated source labels.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(yt=yt):
+
+ classes = np.unique(yt)
+ n = len(classes)
+ D1 = np.zeros((n, len(yt)))
+
+ # perform label propagation
+ transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+
+ # set nans to 0
+ transp[~ np.isfinite(transp)] = 0
+
+ if np.min(classes) != 0:
+ yt = yt - np.min(classes)
+ classes = np.unique(yt)
+
+ for c in classes:
+ D1[int(c), yt == c] = 1
+
+ # compute propagated labels
+ transp_ys = np.dot(D1, transp.T)
+
+ return np.argmax(transp_ys, axis=0)
+
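A hedged usage sketch of the two new methods on a fitted transport; EMDTransport is only one example of a BaseTransport subclass with a fitted coupling_, and the dataset helpers mirror the ones used in the tests below.

import ot
from ot.datasets import make_data_classif

Xs, ys = make_data_classif('3gauss', 100)
Xt, yt = make_data_classif('3gauss2', 120)

otda = ot.da.EMDTransport()
otda.fit(Xs=Xs, ys=ys, Xt=Xt)

yt_est = otda.transform_labels(ys)            # estimated labels, one per target sample
ys_est = otda.inverse_transform_labels(yt)    # estimated labels, one per source sample
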
class LinearTransport(BaseTransport):
@@ -2017,10 +2095,10 @@ class JCPOTTransport(BaseTransport):
Parameters
----------
- Xs : array-like, shape (n_source_samples, n_features)
- The training input samples.
- ys : array-like, shape (n_source_samples,)
- The class labels
+ Xs : list of K array-like objects, shape K x (nk_source_samples, n_features)
+ A list of the training input samples.
+ ys : list of K array-like objects, shape K x (nk_source_samples,)
+ A list of the class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
@@ -2083,3 +2161,88 @@ class JCPOTTransport(BaseTransport):
transp_Xs = np.concatenate(transp_Xs, axis=0)
return transp_Xs
+
+ def transform_labels(self, ys=None):
+ """Propagate source labels ys to obtain target labels
+
+ Parameters
+ ----------
+ ys : list of K array-like objects, shape K x (nk_source_samples,)
+ A list of the class labels
+
+ Returns
+ -------
+ yt : array-like, shape (n_target_samples,)
+ Estimated target labels.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(ys=ys):
+ yt = np.zeros((len(np.unique(np.concatenate(ys))), self.xt_.shape[0]))
+ for i in range(len(ys)):
+ classes = np.unique(ys[i])
+ n = len(classes)
+ ns = len(ys[i])
+
+ # perform label propagation
+ transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None]
+
+ # set nans to 0
+ transp[~ np.isfinite(transp)] = 0
+
+ if self.log:
+ D1 = self.log_['D1'][i]
+ else:
+ D1 = np.zeros((n, ns))
+
+ ysTemp = ys[i]
+ if np.min(classes) != 0:
+ ysTemp = ysTemp - np.min(classes)
+ classes = np.unique(ysTemp)
+
+ for c in classes:
+ D1[int(c), ysTemp == c] = 1
+ # accumulate propagated label scores, averaged over the K source domains
+ yt = yt + np.dot(D1, transp) / len(ys)
+
+ return np.argmax(yt, axis=0)
+
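The multi-source method above repeats the single-source propagation once per domain and averages the class scores over the K couplings; a toy sketch of just that averaging, with made-up couplings and labels:

import numpy as np

couplings = [np.array([[0.4, 0.1], [0.0, 0.5]]),
             np.array([[0.2, 0.2], [0.3, 0.3]])]   # toy coupling_[i], one per source domain
ys_list = [np.array([0, 1]), np.array([1, 0])]     # toy source labels per domain

K = len(ys_list)
scores = np.zeros((2, 2))                          # (n_classes, n_target_samples)
for G, y in zip(couplings, ys_list):
    transp = G / G.sum(1, keepdims=True)           # row-normalised plan for this domain
    transp[~np.isfinite(transp)] = 0
    D1 = np.zeros((2, len(y)))                     # one-hot labels for this domain
    D1[y, np.arange(len(y))] = 1
    scores += D1.dot(transp) / K                   # average over the K domains
yt_est = scores.argmax(axis=0)
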
+ def inverse_transform_labels(self, yt=None):
+ """Propagate source labels ys to obtain target labels
+
+ Parameters
+ ----------
+ yt : array-like, shape (n_target_samples,)
+ The target class labels
+
+ Returns
+ -------
+ transp_ys : list of K array-like objects, shape K x (nk_source_samples,)
+ A list of estimated source labels
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(yt=yt):
+ transp_ys = []
+ classes = np.unique(yt)
+ n = len(classes)
+ D1 = np.zeros((n, len(yt)))
+
+ if np.min(classes) != 0:
+ yt = yt - np.min(classes)
+ classes = np.unique(yt)
+
+ for c in classes:
+ D1[int(c), yt == c] = 1
+
+ for i in range(len(self.xs_)):
+
+ # perform label propagation
+ transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None]
+
+ # set nans to 0
+ transp[~ np.isfinite(transp)] = 0
+
+ # compute transported labels
+ transp_ys.append(np.argmax(np.dot(D1, transp.T), axis=0))
+
+ return transp_ys
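Finally, a hedged end-to-end sketch of the JCPOT variants; the parameter values are illustrative only, and the dataset helpers are the same ones the tests use.

import ot
from ot.datasets import make_data_classif

Xs1, ys1 = make_data_classif('3gauss', 50)
Xs2, ys2 = make_data_classif('3gauss', 60)
Xt, yt = make_data_classif('3gauss2', 80)

otda = ot.da.JCPOTTransport(reg_e=1., max_iter=1000, metric='sqeuclidean')
otda.fit(Xs=[Xs1, Xs2], ys=[ys1, ys2], Xt=Xt)

yt_est = otda.transform_labels([ys1, ys2])    # one label vector for the target domain
ys_est = otda.inverse_transform_labels(yt)    # a list of K label vectors, one per source domain
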
diff --git a/test/test_da.py b/test/test_da.py
index c54dab7..4eb6de0 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -65,6 +65,14 @@ def test_sinkhorn_lpl1_transport_class():
transp_Xs = otda.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ assert_equal(transp_ys.shape[0], ys.shape[0])
+
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornLpl1Transport()
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
@@ -129,6 +137,14 @@ def test_sinkhorn_l1l2_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ assert_equal(transp_ys.shape[0], ys.shape[0])
+
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
@@ -210,6 +226,14 @@ def test_sinkhorn_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ assert_equal(transp_ys.shape[0], ys.shape[0])
+
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
@@ -271,6 +295,14 @@ def test_unbalanced_sinkhorn_transport_class():
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ assert_equal(transp_ys.shape[0], ys.shape[0])
+
Xs_new, _ = make_data_classif('3gauss', ns + 1)
transp_Xs_new = otda.transform(Xs_new)
@@ -353,6 +385,14 @@ def test_emd_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ assert_equal(transp_ys.shape[0], ys.shape[0])
+
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
@@ -602,6 +642,13 @@ def test_jcpot_transport_class():
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ for x, y in zip(transp_ys, ys):
+ assert_equal(x.shape, y.shape)
def test_jcpot_barycenter():
"""test_jcpot_barycenter