diff options
author | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-15 11:12:23 +0200 |
---|---|---|
committer | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-15 11:12:23 +0200 |
commit | 749378a50abd763c87f5cf24a4b2e0dff2a6ec6a (patch) | |
tree | e8aeb51b7bcf3b48fead20fa44eee154da5b8d05 /test/test_da.py | |
parent | 1a4c264cc9b2cb0bb89840ee9175177e86eef3ef (diff) |
fix soft labels, remove gammas from jcpot
Diffstat (limited to 'test/test_da.py')
-rw-r--r-- | test/test_da.py | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/test/test_da.py b/test/test_da.py index d96046d..70296bf 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -68,10 +68,12 @@ def test_sinkhorn_lpl1_transport_class(): # check label propagation transp_yt = otda.transform_labels(ys) assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) # check inverse label propagation transp_ys = otda.inverse_transform_labels(yt) assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) # test unsupervised vs semi-supervised mode otda_unsup = ot.da.SinkhornLpl1Transport() @@ -140,10 +142,12 @@ def test_sinkhorn_l1l2_transport_class(): # check label propagation transp_yt = otda.transform_labels(ys) assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) # check inverse label propagation transp_ys = otda.inverse_transform_labels(yt) assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) Xt_new, _ = make_data_classif('3gauss2', nt + 1) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) @@ -229,10 +233,12 @@ def test_sinkhorn_transport_class(): # check label propagation transp_yt = otda.transform_labels(ys) assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) # check inverse label propagation transp_ys = otda.inverse_transform_labels(yt) assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) Xt_new, _ = make_data_classif('3gauss2', nt + 1) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) @@ -298,10 +304,12 @@ def test_unbalanced_sinkhorn_transport_class(): # check label propagation transp_yt = otda.transform_labels(ys) assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) # check inverse label propagation transp_ys = otda.inverse_transform_labels(yt) assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) Xs_new, _ = make_data_classif('3gauss', ns + 1) transp_Xs_new = otda.transform(Xs_new) @@ -388,10 +396,12 @@ def test_emd_transport_class(): # check label propagation transp_yt = otda.transform_labels(ys) assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) # check inverse label propagation transp_ys = otda.inverse_transform_labels(yt) assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) Xt_new, _ = make_data_classif('3gauss2', nt + 1) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) @@ -645,10 +655,12 @@ def test_jcpot_transport_class(): # check label propagation transp_yt = otda.transform_labels(ys) assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) # check inverse label propagation transp_ys = otda.inverse_transform_labels(yt) - [assert_equal(x.shape, y.shape) for x, y in zip(transp_ys, ys)] + [assert_equal(x.shape[0], y.shape[0]) for x, y in zip(transp_ys, ys)] + [assert_equal(x.shape[1], len(np.unique(y))) for x, y in zip(transp_ys, ys)] def test_jcpot_barycenter(): |