summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-08-04 12:04:04 +0200
committerNicolas Courty <Nico@MacBook-Pro-de-Nicolas.local>2017-09-01 11:09:13 +0200
commit4e562a1ce24119b8c9c1efb9d078762904c5d78a (patch)
tree0487c2d8c5fc3ef55524a988314ed79144eaf45d /test
parent2005a09548a6f6d42cd9aafadbb4583e4029936c (diff)
semi supervised mode supported
Diffstat (limited to 'test')
-rw-r--r--test/test_da.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 68d1958..497a8ee 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -62,6 +62,19 @@ def test_sinkhorn_lpl1_transport_class():
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(clf.Cost)
+
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(clf.Cost)
+
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
def test_sinkhorn_l1l2_transport_class():
"""test_sinkhorn_transport
@@ -112,6 +125,19 @@ def test_sinkhorn_l1l2_transport_class():
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(clf.Cost)
+
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(clf.Cost)
+
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
def test_sinkhorn_transport_class():
"""test_sinkhorn_transport
@@ -162,6 +188,19 @@ def test_sinkhorn_transport_class():
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(clf.Cost)
+
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(clf.Cost)
+
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
def test_emd_transport_class():
"""test_sinkhorn_transport
@@ -212,6 +251,19 @@ def test_emd_transport_class():
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(clf.Cost)
+
+ # test semi supervised mode
+ clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(clf.Cost)
+
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
def test_otda():