From 326d163db029515c338a963978a5d95948f78c29 Mon Sep 17 00:00:00 2001 From: Slasnista Date: Wed, 23 Aug 2017 14:11:13 +0200 Subject: test functions for MappingTransport Class --- test/test_da.py | 117 +++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 112 insertions(+), 5 deletions(-) (limited to 'test/test_da.py') diff --git a/test/test_da.py b/test/test_da.py index 196f4c4..162f681 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -264,6 +264,112 @@ def test_emd_transport_class(): assert n_unsup != n_semisup, "semisupervised mode not working" +def test_mapping_transport_class(): + """test_mapping_transport + """ + + ns = 150 + nt = 200 + + Xs, ys = get_data_classif('3gauss', ns) + Xt, yt = get_data_classif('3gauss2', nt) + Xs_new, _ = get_data_classif('3gauss', ns + 1) + + ########################################################################## + # kernel == linear mapping tests + ########################################################################## + + # check computation and dimensions if bias == False + clf = ot.da.MappingTransport(kernel="linear", bias=False) + clf.fit(Xs=Xs, Xt=Xt) + + assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) + assert_equal(clf.mapping_.shape, ((Xs.shape[1], Xt.shape[1]))) + + # test margin constraints + mu_s = unif(ns) + mu_t = unif(nt) + assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + + # test transform + transp_Xs = clf.transform(Xs=Xs) + assert_equal(transp_Xs.shape, Xs.shape) + + transp_Xs_new = clf.transform(Xs_new) + + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) + + # check computation and dimensions if bias == True + clf = ot.da.MappingTransport(kernel="linear", bias=True) + clf.fit(Xs=Xs, Xt=Xt) + assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) + assert_equal(clf.mapping_.shape, ((Xs.shape[1] + 1, Xt.shape[1]))) + + # test margin constraints + mu_s = unif(ns) + mu_t = unif(nt) + assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + + # test transform + transp_Xs = clf.transform(Xs=Xs) + assert_equal(transp_Xs.shape, Xs.shape) + + transp_Xs_new = clf.transform(Xs_new) + + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) + + ########################################################################## + # kernel == gaussian mapping tests + ########################################################################## + + # check computation and dimensions if bias == False + clf = ot.da.MappingTransport(kernel="gaussian", bias=False) + clf.fit(Xs=Xs, Xt=Xt) + + assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) + assert_equal(clf.mapping_.shape, ((Xs.shape[0], Xt.shape[1]))) + + # test margin constraints + mu_s = unif(ns) + mu_t = unif(nt) + assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + + # test transform + transp_Xs = clf.transform(Xs=Xs) + assert_equal(transp_Xs.shape, Xs.shape) + + transp_Xs_new = clf.transform(Xs_new) + + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) + + # check computation and dimensions if bias == True + clf = ot.da.MappingTransport(kernel="gaussian", bias=True) + clf.fit(Xs=Xs, Xt=Xt) + assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) + assert_equal(clf.mapping_.shape, ((Xs.shape[0] + 1, Xt.shape[1]))) + + # test margin constraints + mu_s = unif(ns) + mu_t = unif(nt) + assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + + # test transform + transp_Xs = clf.transform(Xs=Xs) + assert_equal(transp_Xs.shape, Xs.shape) + + transp_Xs_new = clf.transform(Xs_new) + + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) + + def test_otda(): n_samples = 150 # nb samples @@ -326,9 +432,10 @@ def test_otda(): da_emd.predict(xs) # interpolation of source samples -# if __name__ == "__main__": +if __name__ == "__main__": -# test_sinkhorn_transport_class() -# test_emd_transport_class() -# test_sinkhorn_l1l2_transport_class() -# test_sinkhorn_lpl1_transport_class() + # test_sinkhorn_transport_class() + # test_emd_transport_class() + # test_sinkhorn_l1l2_transport_class() + # test_sinkhorn_lpl1_transport_class() + test_mapping_transport_class() -- cgit v1.2.3