summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py32
1 files changed, 16 insertions, 16 deletions
diff --git a/test/test_da.py b/test/test_da.py
index ecd2a3a..aed9f61 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -45,8 +45,8 @@ def test_sinkhorn_lpl1_transport_class():
Xs_new, _ = get_data_classif('3gauss', ns + 1)
transp_Xs_new = clf.transform(Xs_new)
- # check that the oos method is not working
- assert_equal(transp_Xs_new, Xs_new)
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
# test inverse transform
transp_Xt = clf.inverse_transform(Xt=Xt)
@@ -55,8 +55,8 @@ def test_sinkhorn_lpl1_transport_class():
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
- # check that the oos method is not working and returns the input data
- assert_equal(transp_Xt_new, Xt_new)
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
# test fit_transform
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
@@ -108,8 +108,8 @@ def test_sinkhorn_l1l2_transport_class():
Xs_new, _ = get_data_classif('3gauss', ns + 1)
transp_Xs_new = clf.transform(Xs_new)
- # check that the oos method is not working
- assert_equal(transp_Xs_new, Xs_new)
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
# test inverse transform
transp_Xt = clf.inverse_transform(Xt=Xt)
@@ -118,8 +118,8 @@ def test_sinkhorn_l1l2_transport_class():
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
- # check that the oos method is not working and returns the input data
- assert_equal(transp_Xt_new, Xt_new)
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
# test fit_transform
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
@@ -171,8 +171,8 @@ def test_sinkhorn_transport_class():
Xs_new, _ = get_data_classif('3gauss', ns + 1)
transp_Xs_new = clf.transform(Xs_new)
- # check that the oos method is not working
- assert_equal(transp_Xs_new, Xs_new)
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
# test inverse transform
transp_Xt = clf.inverse_transform(Xt=Xt)
@@ -181,8 +181,8 @@ def test_sinkhorn_transport_class():
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
- # check that the oos method is not working and returns the input data
- assert_equal(transp_Xt_new, Xt_new)
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
# test fit_transform
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)
@@ -234,8 +234,8 @@ def test_emd_transport_class():
Xs_new, _ = get_data_classif('3gauss', ns + 1)
transp_Xs_new = clf.transform(Xs_new)
- # check that the oos method is not working
- assert_equal(transp_Xs_new, Xs_new)
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
# test inverse transform
transp_Xt = clf.inverse_transform(Xt=Xt)
@@ -244,8 +244,8 @@ def test_emd_transport_class():
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
- # check that the oos method is not working and returns the input data
- assert_equal(transp_Xt_new, Xt_new)
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
# test fit_transform
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)