summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-08 16:05:01 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-08 16:05:01 +0200
commit0b402fd7c7e07043afd3a9df9d75bc424730b06f (patch)
tree9d41a03bb510bc8a6105c9e4d102ad5a274f3a9b /test
parentbc51793333994a1bf6263c9e9c111d754172fc82 (diff)
add label prop + inverse
Diffstat (limited to 'test')
-rw-r--r--test/test_da.py47
1 files changed, 47 insertions, 0 deletions
diff --git a/test/test_da.py b/test/test_da.py
index c54dab7..4eb6de0 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -65,6 +65,14 @@ def test_sinkhorn_lpl1_transport_class():
transp_Xs = otda.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ assert_equal(transp_ys.shape[0], ys.shape[0])
+
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornLpl1Transport()
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
@@ -129,6 +137,14 @@ def test_sinkhorn_l1l2_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ assert_equal(transp_ys.shape[0], ys.shape[0])
+
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
@@ -210,6 +226,14 @@ def test_sinkhorn_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ assert_equal(transp_ys.shape[0], ys.shape[0])
+
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
@@ -271,6 +295,14 @@ def test_unbalanced_sinkhorn_transport_class():
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ assert_equal(transp_ys.shape[0], ys.shape[0])
+
Xs_new, _ = make_data_classif('3gauss', ns + 1)
transp_Xs_new = otda.transform(Xs_new)
@@ -353,6 +385,14 @@ def test_emd_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ assert_equal(transp_ys.shape[0], ys.shape[0])
+
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
@@ -602,6 +642,13 @@ def test_jcpot_transport_class():
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ [assert_equal(x.shape, y.shape) for x, y in zip(transp_ys, ys)]
def test_jcpot_barycenter():
"""test_jcpot_barycenter