summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-08-04 14:55:54 +0200
committerSlasnista <stan.chambon@gmail.com>2017-08-04 14:55:54 +0200
commit738bfb1c560ff4e349f5083fc1f81a54e4be4980 (patch)
treead079d5f2f94985f504140973d4d2418c21245d4 /test/test_da.py
parent778f4f76d7f162e7630c9ba5369a0e389e18433c (diff)
out of samples by Ferradans supported for transform and inverse_transform
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py32
1 files changed, 16 insertions, 16 deletions
diff --git a/test/test_da.py b/test/test_da.py
index ecd2a3a..aed9f61 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -45,8 +45,8 @@ def test_sinkhorn_lpl1_transport_class():
Xs_new, _ = get_data_classif('3gauss', ns + 1)
transp_Xs_new = clf.transform(Xs_new)
- # check that the oos method is not working
- assert_equal(transp_Xs_new, Xs_new)
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
# test inverse transform
transp_Xt = clf.inverse_transform(Xt=Xt)
@@ -55,8 +55,8 @@ def test_sinkhorn_lpl1_transport_class():
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
- # check that the oos method is not working and returns the input data
- assert_equal(transp_Xt_new, Xt_new)
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
# test fit_transform
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
@@ -108,8 +108,8 @@ def test_sinkhorn_l1l2_transport_class():
Xs_new, _ = get_data_classif('3gauss', ns + 1)
transp_Xs_new = clf.transform(Xs_new)
- # check that the oos method is not working
- assert_equal(transp_Xs_new, Xs_new)
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
# test inverse transform
transp_Xt = clf.inverse_transform(Xt=Xt)
@@ -118,8 +118,8 @@ def test_sinkhorn_l1l2_transport_class():
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
- # check that the oos method is not working and returns the input data
- assert_equal(transp_Xt_new, Xt_new)
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
# test fit_transform
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
@@ -171,8 +171,8 @@ def test_sinkhorn_transport_class():
Xs_new, _ = get_data_classif('3gauss', ns + 1)
transp_Xs_new = clf.transform(Xs_new)
- # check that the oos method is not working
- assert_equal(transp_Xs_new, Xs_new)
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
# test inverse transform
transp_Xt = clf.inverse_transform(Xt=Xt)
@@ -181,8 +181,8 @@ def test_sinkhorn_transport_class():
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
- # check that the oos method is not working and returns the input data
- assert_equal(transp_Xt_new, Xt_new)
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
# test fit_transform
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)
@@ -234,8 +234,8 @@ def test_emd_transport_class():
Xs_new, _ = get_data_classif('3gauss', ns + 1)
transp_Xs_new = clf.transform(Xs_new)
- # check that the oos method is not working
- assert_equal(transp_Xs_new, Xs_new)
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
# test inverse transform
transp_Xt = clf.inverse_transform(Xt=Xt)
@@ -244,8 +244,8 @@ def test_emd_transport_class():
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
- # check that the oos method is not working and returns the input data
- assert_equal(transp_Xt_new, Xt_new)
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
# test fit_transform
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)