summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/da.py21
-rw-r--r--test/test_da.py52
2 files changed, 71 insertions, 2 deletions
diff --git a/ot/da.py b/ot/da.py
index 6100d15..8294e8d 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -1089,8 +1089,25 @@ class BaseTransport(BaseEstimator):
self.Cost = dist(Xs, Xt, metric=self.metric)
if self.mode == "semisupervised":
- print("TODO: modify cost matrix accordingly")
- pass
+
+ if (ys is not None) and (yt is not None):
+
+ # assumes labeled source samples occupy the first rows
+ # and labeled target samples occupy the first columns
+ classes = np.unique(ys)
+ for c in classes:
+ ids = np.where(ys == c)
+ idt = np.where(yt == c)
+
+ # all the coefficients corresponding to a source sample
+ # and a target sample with the same label gets a 0
+ # transport cost
+ for j in idt[0]:
+ self.Cost[ids[0], j] = 0
+ else:
+ print("Warning: using unsupervised mode\
+ \nto use semisupervised mode, please provide ys and yt")
+ pass
# distribution estimation
self.mu_s = self.distribution_estimation(Xs)
diff --git a/test/test_da.py b/test/test_da.py
index 68d1958..497a8ee 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -62,6 +62,19 @@ def test_sinkhorn_lpl1_transport_class():
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(clf.Cost)
+
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(clf.Cost)
+
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
def test_sinkhorn_l1l2_transport_class():
"""test_sinkhorn_transport
@@ -112,6 +125,19 @@ def test_sinkhorn_l1l2_transport_class():
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(clf.Cost)
+
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(clf.Cost)
+
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
def test_sinkhorn_transport_class():
"""test_sinkhorn_transport
@@ -162,6 +188,19 @@ def test_sinkhorn_transport_class():
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(clf.Cost)
+
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(clf.Cost)
+
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
def test_emd_transport_class():
"""test_sinkhorn_transport
@@ -212,6 +251,19 @@ def test_emd_transport_class():
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(clf.Cost)
+
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(clf.Cost)
+
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
def test_otda():