From 4e562a1ce24119b8c9c1efb9d078762904c5d78a Mon Sep 17 00:00:00 2001 From: Slasnista Date: Fri, 4 Aug 2017 12:04:04 +0200 Subject: semi supervised mode supported --- ot/da.py | 21 +++++++++++++++++++-- test/test_da.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/ot/da.py b/ot/da.py index 6100d15..8294e8d 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1089,8 +1089,25 @@ class BaseTransport(BaseEstimator): self.Cost = dist(Xs, Xt, metric=self.metric) if self.mode == "semisupervised": - print("TODO: modify cost matrix accordingly") - pass + + if (ys is not None) and (yt is not None): + + # assumes labeled source samples occupy the first rows + # and labeled target samples occupy the first columns + classes = np.unique(ys) + for c in classes: + ids = np.where(ys == c) + idt = np.where(yt == c) + + # all the coefficients corresponding to a source sample + # and a target sample with the same label gets a 0 + # transport cost + for j in idt[0]: + self.Cost[ids[0], j] = 0 + else: + print("Warning: using unsupervised mode\ + \nto use semisupervised mode, please provide ys and yt") + pass # distribution estimation self.mu_s = self.distribution_estimation(Xs) diff --git a/test/test_da.py b/test/test_da.py index 68d1958..497a8ee 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -62,6 +62,19 @@ def test_sinkhorn_lpl1_transport_class(): transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt) assert_equal(transp_Xs.shape, Xs.shape) + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, Xt=Xt) + n_unsup = np.sum(clf.Cost) + + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0]))) + n_semisup = np.sum(clf.Cost) + + assert n_unsup != n_semisup, "semisupervised mode not working" + def test_sinkhorn_l1l2_transport_class(): """test_sinkhorn_transport @@ -112,6 +125,19 @@ def test_sinkhorn_l1l2_transport_class(): transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt) assert_equal(transp_Xs.shape, Xs.shape) + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, Xt=Xt) + n_unsup = np.sum(clf.Cost) + + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0]))) + n_semisup = np.sum(clf.Cost) + + assert n_unsup != n_semisup, "semisupervised mode not working" + def test_sinkhorn_transport_class(): """test_sinkhorn_transport @@ -162,6 +188,19 @@ def test_sinkhorn_transport_class(): transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt) assert_equal(transp_Xs.shape, Xs.shape) + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, Xt=Xt) + n_unsup = np.sum(clf.Cost) + + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0]))) + n_semisup = np.sum(clf.Cost) + + assert n_unsup != n_semisup, "semisupervised mode not working" + def test_emd_transport_class(): """test_sinkhorn_transport @@ -212,6 +251,19 @@ def test_emd_transport_class(): transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt) assert_equal(transp_Xs.shape, Xs.shape) + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, Xt=Xt) + n_unsup = np.sum(clf.Cost) + + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0]))) + n_semisup = np.sum(clf.Cost) + + assert n_unsup != n_semisup, "semisupervised mode not working" + def test_otda(): -- cgit v1.2.3