From 0b005906f9d78adbf4d52d2ea9610eb3fde96a7c Mon Sep 17 00:00:00 2001 From: Slasnista Date: Fri, 4 Aug 2017 12:04:04 +0200 Subject: semi supervised mode supported --- test/test_da.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) (limited to 'test/test_da.py') diff --git a/test/test_da.py b/test/test_da.py index 68d1958..497a8ee 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -62,6 +62,19 @@ def test_sinkhorn_lpl1_transport_class(): transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt) assert_equal(transp_Xs.shape, Xs.shape) + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, Xt=Xt) + n_unsup = np.sum(clf.Cost) + + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0]))) + n_semisup = np.sum(clf.Cost) + + assert n_unsup != n_semisup, "semisupervised mode not working" + def test_sinkhorn_l1l2_transport_class(): """test_sinkhorn_transport @@ -112,6 +125,19 @@ def test_sinkhorn_l1l2_transport_class(): transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt) assert_equal(transp_Xs.shape, Xs.shape) + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, Xt=Xt) + n_unsup = np.sum(clf.Cost) + + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0]))) + n_semisup = np.sum(clf.Cost) + + assert n_unsup != n_semisup, "semisupervised mode not working" + def test_sinkhorn_transport_class(): """test_sinkhorn_transport @@ -162,6 +188,19 @@ def test_sinkhorn_transport_class(): transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt) assert_equal(transp_Xs.shape, Xs.shape) + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, Xt=Xt) + n_unsup = np.sum(clf.Cost) + + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0]))) + n_semisup = np.sum(clf.Cost) + + assert n_unsup != n_semisup, "semisupervised mode not working" + def test_emd_transport_class(): """test_sinkhorn_transport @@ -212,6 +251,19 @@ def test_emd_transport_class(): transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt) assert_equal(transp_Xs.shape, Xs.shape) + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, Xt=Xt) + n_unsup = np.sum(clf.Cost) + + # test semi supervised mode + clf = ot.da.SinkhornTransport(mode="semisupervised") + clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0]))) + n_semisup = np.sum(clf.Cost) + + assert n_unsup != n_semisup, "semisupervised mode not working" + def test_otda(): -- cgit v1.2.3