From 2d4d0b46f88c66ebc5502c840703ba6ce8910376 Mon Sep 17 00:00:00 2001 From: Slasnista Date: Fri, 25 Aug 2017 10:29:41 +0200 Subject: solving log issues to avoid errors and adding further tests --- test/test_da.py | 39 +++++++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 6 deletions(-) (limited to 'test') diff --git a/test/test_da.py b/test/test_da.py index 9578b3d..104a798 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -26,6 +26,8 @@ def test_sinkhorn_lpl1_transport_class(): # test its computed clf.fit(Xs=Xs, ys=ys, Xt=Xt) + assert hasattr(clf, "cost_") + assert hasattr(clf, "coupling_") # test dimensions of coupling assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) @@ -89,6 +91,9 @@ def test_sinkhorn_l1l2_transport_class(): # test its computed clf.fit(Xs=Xs, ys=ys, Xt=Xt) + assert hasattr(clf, "cost_") + assert hasattr(clf, "coupling_") + assert hasattr(clf, "log_") # test dimensions of coupling assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) @@ -137,6 +142,11 @@ def test_sinkhorn_l1l2_transport_class(): assert n_unsup != n_semisup, "semisupervised mode not working" + # check everything runs well with log=True + clf = ot.da.SinkhornL1l2Transport(log=True) + clf.fit(Xs=Xs, ys=ys, Xt=Xt) + assert len(clf.log_.keys()) != 0 + def test_sinkhorn_transport_class(): """test_sinkhorn_transport @@ -152,6 +162,9 @@ def test_sinkhorn_transport_class(): # test its computed clf.fit(Xs=Xs, Xt=Xt) + assert hasattr(clf, "cost_") + assert hasattr(clf, "coupling_") + assert hasattr(clf, "log_") # test dimensions of coupling assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) @@ -200,6 +213,11 @@ def test_sinkhorn_transport_class(): assert n_unsup != n_semisup, "semisupervised mode not working" + # check everything runs well with log=True + clf = ot.da.SinkhornTransport(log=True) + clf.fit(Xs=Xs, ys=ys, Xt=Xt) + assert len(clf.log_.keys()) != 0 + def test_emd_transport_class(): """test_sinkhorn_transport @@ -215,6 +233,8 @@ def test_emd_transport_class(): # test its computed clf.fit(Xs=Xs, Xt=Xt) + assert hasattr(clf, "cost_") + assert hasattr(clf, "coupling_") # test dimensions of coupling assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) @@ -282,6 +302,9 @@ def test_mapping_transport_class(): # check computation and dimensions if bias == False clf = ot.da.MappingTransport(kernel="linear", bias=False) clf.fit(Xs=Xs, Xt=Xt) + assert hasattr(clf, "coupling_") + assert hasattr(clf, "mapping_") + assert hasattr(clf, "log_") assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) assert_equal(clf.mapping_.shape, ((Xs.shape[1], Xt.shape[1]))) @@ -369,6 +392,11 @@ def test_mapping_transport_class(): # check that the oos method is working assert_equal(transp_Xs_new.shape, Xs_new.shape) + # check everything runs well with log=True + clf = ot.da.MappingTransport(kernel="gaussian", log=True) + clf.fit(Xs=Xs, Xt=Xt) + assert len(clf.log_.keys()) != 0 + def test_otda(): @@ -434,9 +462,8 @@ def test_otda(): # if __name__ == "__main__": - # test_otda() - # test_sinkhorn_transport_class() - # test_emd_transport_class() - # test_sinkhorn_l1l2_transport_class() - # test_sinkhorn_lpl1_transport_class() - # test_mapping_transport_class() +# test_sinkhorn_transport_class() +# test_emd_transport_class() +# test_sinkhorn_l1l2_transport_class() +# test_sinkhorn_lpl1_transport_class() +# test_mapping_transport_class() -- cgit v1.2.3