From 738bfb1c560ff4e349f5083fc1f81a54e4be4980 Mon Sep 17 00:00:00 2001 From: Slasnista Date: Fri, 4 Aug 2017 14:55:54 +0200 Subject: out of samples by Ferradans supported for transform and inverse_transform --- test/test_da.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) (limited to 'test/test_da.py') diff --git a/test/test_da.py b/test/test_da.py index ecd2a3a..aed9f61 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -45,8 +45,8 @@ def test_sinkhorn_lpl1_transport_class(): Xs_new, _ = get_data_classif('3gauss', ns + 1) transp_Xs_new = clf.transform(Xs_new) - # check that the oos method is not working - assert_equal(transp_Xs_new, Xs_new) + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) # test inverse transform transp_Xt = clf.inverse_transform(Xt=Xt) @@ -55,8 +55,8 @@ def test_sinkhorn_lpl1_transport_class(): Xt_new, _ = get_data_classif('3gauss2', nt + 1) transp_Xt_new = clf.inverse_transform(Xt=Xt_new) - # check that the oos method is not working and returns the input data - assert_equal(transp_Xt_new, Xt_new) + # check that the oos method is working + assert_equal(transp_Xt_new.shape, Xt_new.shape) # test fit_transform transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt) @@ -108,8 +108,8 @@ def test_sinkhorn_l1l2_transport_class(): Xs_new, _ = get_data_classif('3gauss', ns + 1) transp_Xs_new = clf.transform(Xs_new) - # check that the oos method is not working - assert_equal(transp_Xs_new, Xs_new) + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) # test inverse transform transp_Xt = clf.inverse_transform(Xt=Xt) @@ -118,8 +118,8 @@ def test_sinkhorn_l1l2_transport_class(): Xt_new, _ = get_data_classif('3gauss2', nt + 1) transp_Xt_new = clf.inverse_transform(Xt=Xt_new) - # check that the oos method is not working and returns the input data - assert_equal(transp_Xt_new, Xt_new) + # check that the oos method is working + assert_equal(transp_Xt_new.shape, Xt_new.shape) # test fit_transform transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt) @@ -171,8 +171,8 @@ def test_sinkhorn_transport_class(): Xs_new, _ = get_data_classif('3gauss', ns + 1) transp_Xs_new = clf.transform(Xs_new) - # check that the oos method is not working - assert_equal(transp_Xs_new, Xs_new) + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) # test inverse transform transp_Xt = clf.inverse_transform(Xt=Xt) @@ -181,8 +181,8 @@ def test_sinkhorn_transport_class(): Xt_new, _ = get_data_classif('3gauss2', nt + 1) transp_Xt_new = clf.inverse_transform(Xt=Xt_new) - # check that the oos method is not working and returns the input data - assert_equal(transp_Xt_new, Xt_new) + # check that the oos method is working + assert_equal(transp_Xt_new.shape, Xt_new.shape) # test fit_transform transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt) @@ -234,8 +234,8 @@ def test_emd_transport_class(): Xs_new, _ = get_data_classif('3gauss', ns + 1) transp_Xs_new = clf.transform(Xs_new) - # check that the oos method is not working - assert_equal(transp_Xs_new, Xs_new) + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) # test inverse transform transp_Xt = clf.inverse_transform(Xt=Xt) @@ -244,8 +244,8 @@ def test_emd_transport_class(): Xt_new, _ = get_data_classif('3gauss2', nt + 1) transp_Xt_new = clf.inverse_transform(Xt=Xt_new) - # check that the oos method is not working and returns the input data - assert_equal(transp_Xt_new, Xt_new) + # check that the oos method is working + assert_equal(transp_Xt_new.shape, Xt_new.shape) # test fit_transform transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt) -- cgit v1.2.3