diff options
author | Slasnista <stan.chambon@gmail.com> | 2017-08-04 12:04:04 +0200 |
---|---|---|
committer | Slasnista <stan.chambon@gmail.com> | 2017-08-04 12:04:04 +0200 |
commit | 0b005906f9d78adbf4d52d2ea9610eb3fde96a7c (patch) | |
tree | 0487c2d8c5fc3ef55524a988314ed79144eaf45d /test/test_da.py | |
parent | 727077ad7db503955aea0751abf9f361f1d82af7 (diff) |
semi supervised mode supported
Diffstat (limited to 'test/test_da.py')
-rw-r--r-- | test/test_da.py | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/test/test_da.py b/test/test_da.py index 68d1958..497a8ee 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -62,6 +62,19 @@ def test_sinkhorn_lpl1_transport_class(): transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt) assert_equal(transp_Xs.shape, Xs.shape) + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, Xt=Xt) + n_unsup = np.sum(clf.Cost) + + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0]))) + n_semisup = np.sum(clf.Cost) + + assert n_unsup != n_semisup, "semisupervised mode not working" + def test_sinkhorn_l1l2_transport_class(): """test_sinkhorn_transport @@ -112,6 +125,19 @@ def test_sinkhorn_l1l2_transport_class(): transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt) assert_equal(transp_Xs.shape, Xs.shape) + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, Xt=Xt) + n_unsup = np.sum(clf.Cost) + + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0]))) + n_semisup = np.sum(clf.Cost) + + assert n_unsup != n_semisup, "semisupervised mode not working" + def test_sinkhorn_transport_class(): """test_sinkhorn_transport @@ -162,6 +188,19 @@ def test_sinkhorn_transport_class(): transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt) assert_equal(transp_Xs.shape, Xs.shape) + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, Xt=Xt) + n_unsup = np.sum(clf.Cost) + + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0]))) + n_semisup = np.sum(clf.Cost) + + assert n_unsup != n_semisup, "semisupervised mode not working" + def test_emd_transport_class(): """test_sinkhorn_transport @@ -212,6 +251,19 @@ def test_emd_transport_class(): transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt) assert_equal(transp_Xs.shape, Xs.shape) + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, Xt=Xt) + n_unsup = np.sum(clf.Cost) + + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0]))) + n_semisup = np.sum(clf.Cost) + + assert n_unsup != n_semisup, "semisupervised mode not working" + def test_otda(): |