summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-08-04 12:04:04 +0200
committerSlasnista <stan.chambon@gmail.com>2017-08-04 12:04:04 +0200
commit0b005906f9d78adbf4d52d2ea9610eb3fde96a7c (patch)
tree0487c2d8c5fc3ef55524a988314ed79144eaf45d /test/test_da.py
parent727077ad7db503955aea0751abf9f361f1d82af7 (diff)
semi supervised mode supported
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 68d1958..497a8ee 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -62,6 +62,19 @@ def test_sinkhorn_lpl1_transport_class():
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(clf.Cost)
+
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(clf.Cost)
+
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
def test_sinkhorn_l1l2_transport_class():
"""test_sinkhorn_transport
@@ -112,6 +125,19 @@ def test_sinkhorn_l1l2_transport_class():
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(clf.Cost)
+
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(clf.Cost)
+
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
def test_sinkhorn_transport_class():
"""test_sinkhorn_transport
@@ -162,6 +188,19 @@ def test_sinkhorn_transport_class():
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(clf.Cost)
+
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(clf.Cost)
+
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
def test_emd_transport_class():
"""test_sinkhorn_transport
@@ -212,6 +251,19 @@ def test_emd_transport_class():
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(clf.Cost)
+
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(clf.Cost)
+
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
def test_otda():