summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-15 11:12:23 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-15 11:12:23 +0200
commit749378a50abd763c87f5cf24a4b2e0dff2a6ec6a (patch)
treee8aeb51b7bcf3b48fead20fa44eee154da5b8d05 /test
parent1a4c264cc9b2cb0bb89840ee9175177e86eef3ef (diff)
fix soft labels, remove gammas from jcpot
Diffstat (limited to 'test')
-rw-r--r--test/test_da.py14
1 files changed, 13 insertions, 1 deletions
diff --git a/test/test_da.py b/test/test_da.py
index d96046d..70296bf 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -68,10 +68,12 @@ def test_sinkhorn_lpl1_transport_class():
# check label propagation
transp_yt = otda.transform_labels(ys)
assert_equal(transp_yt.shape[0], yt.shape[0])
+ assert_equal(transp_yt.shape[1], len(np.unique(ys)))
# check inverse label propagation
transp_ys = otda.inverse_transform_labels(yt)
assert_equal(transp_ys.shape[0], ys.shape[0])
+ assert_equal(transp_ys.shape[1], len(np.unique(yt)))
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornLpl1Transport()
@@ -140,10 +142,12 @@ def test_sinkhorn_l1l2_transport_class():
# check label propagation
transp_yt = otda.transform_labels(ys)
assert_equal(transp_yt.shape[0], yt.shape[0])
+ assert_equal(transp_yt.shape[1], len(np.unique(ys)))
# check inverse label propagation
transp_ys = otda.inverse_transform_labels(yt)
assert_equal(transp_ys.shape[0], ys.shape[0])
+ assert_equal(transp_ys.shape[1], len(np.unique(yt)))
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
@@ -229,10 +233,12 @@ def test_sinkhorn_transport_class():
# check label propagation
transp_yt = otda.transform_labels(ys)
assert_equal(transp_yt.shape[0], yt.shape[0])
+ assert_equal(transp_yt.shape[1], len(np.unique(ys)))
# check inverse label propagation
transp_ys = otda.inverse_transform_labels(yt)
assert_equal(transp_ys.shape[0], ys.shape[0])
+ assert_equal(transp_ys.shape[1], len(np.unique(yt)))
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
@@ -298,10 +304,12 @@ def test_unbalanced_sinkhorn_transport_class():
# check label propagation
transp_yt = otda.transform_labels(ys)
assert_equal(transp_yt.shape[0], yt.shape[0])
+ assert_equal(transp_yt.shape[1], len(np.unique(ys)))
# check inverse label propagation
transp_ys = otda.inverse_transform_labels(yt)
assert_equal(transp_ys.shape[0], ys.shape[0])
+ assert_equal(transp_ys.shape[1], len(np.unique(yt)))
Xs_new, _ = make_data_classif('3gauss', ns + 1)
transp_Xs_new = otda.transform(Xs_new)
@@ -388,10 +396,12 @@ def test_emd_transport_class():
# check label propagation
transp_yt = otda.transform_labels(ys)
assert_equal(transp_yt.shape[0], yt.shape[0])
+ assert_equal(transp_yt.shape[1], len(np.unique(ys)))
# check inverse label propagation
transp_ys = otda.inverse_transform_labels(yt)
assert_equal(transp_ys.shape[0], ys.shape[0])
+ assert_equal(transp_ys.shape[1], len(np.unique(yt)))
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
@@ -645,10 +655,12 @@ def test_jcpot_transport_class():
# check label propagation
transp_yt = otda.transform_labels(ys)
assert_equal(transp_yt.shape[0], yt.shape[0])
+ assert_equal(transp_yt.shape[1], len(np.unique(ys)))
# check inverse label propagation
transp_ys = otda.inverse_transform_labels(yt)
- [assert_equal(x.shape, y.shape) for x, y in zip(transp_ys, ys)]
+ [assert_equal(x.shape[0], y.shape[0]) for x, y in zip(transp_ys, ys)]
+ [assert_equal(x.shape[1], len(np.unique(y))) for x, y in zip(transp_ys, ys)]
def test_jcpot_barycenter():