From 49c100de34583329058b39d414d2aa49b7fd15bf Mon Sep 17 00:00:00 2001 From: Slasnista Date: Tue, 5 Sep 2017 10:00:01 +0200 Subject: test semi supervised mode ok written for all class | need different tolerance for EMDTransport MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_da.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/test/test_da.py b/test/test_da.py index 9fc42a3..3602db9 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -64,11 +64,11 @@ def test_sinkhorn_lpl1_transport_class(): assert_equal(transp_Xs.shape, Xs.shape) # test unsupervised vs semi-supervised mode - otda_unsup = ot.da.SinkhornTransport() - otda_unsup.fit(Xs=Xs, Xt=Xt) + otda_unsup = ot.da.SinkhornLpl1Transport() + otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt) n_unsup = np.sum(otda_unsup.cost_) - otda_semi = ot.da.SinkhornTransport() + otda_semi = ot.da.SinkhornLpl1Transport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) n_semisup = np.sum(otda_semi.cost_) @@ -136,11 +136,11 @@ def test_sinkhorn_l1l2_transport_class(): assert_equal(transp_Xs.shape, Xs.shape) # test unsupervised vs semi-supervised mode - otda_unsup = ot.da.SinkhornTransport() - otda_unsup.fit(Xs=Xs, Xt=Xt) + otda_unsup = ot.da.SinkhornL1l2Transport() + otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt) n_unsup = np.sum(otda_unsup.cost_) - otda_semi = ot.da.SinkhornTransport() + otda_semi = ot.da.SinkhornL1l2Transport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) n_semisup = np.sum(otda_semi.cost_) @@ -152,7 +152,9 @@ def test_sinkhorn_l1l2_transport_class(): # and labeled target samples mass_semi = np.sum( otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) - assert mass_semi == 0, "semisupervised mode not working" + mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max] + assert_allclose(mass_semi, np.zeros_like(mass_semi), + rtol=1e-9, atol=1e-9) # check everything runs well with log=True otda = ot.da.SinkhornL1l2Transport(log=True) @@ -289,11 +291,11 @@ def test_emd_transport_class(): assert_equal(transp_Xs.shape, Xs.shape) # test unsupervised vs semi-supervised mode - otda_unsup = ot.da.SinkhornTransport() - otda_unsup.fit(Xs=Xs, Xt=Xt) + otda_unsup = ot.da.EMDTransport() + otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt) n_unsup = np.sum(otda_unsup.cost_) - otda_semi = ot.da.SinkhornTransport() + otda_semi = ot.da.EMDTransport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) n_semisup = np.sum(otda_semi.cost_) @@ -305,7 +307,11 @@ def test_emd_transport_class(): # and labeled target samples mass_semi = np.sum( otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) - assert mass_semi == 0, "semisupervised mode not working" + mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max] + + # we need to use a small tolerance here, otherwise the test breaks + assert_allclose(mass_semi, np.zeros_like(mass_semi), + rtol=1e-2, atol=1e-2) def test_mapping_transport_class(): @@ -491,3 +497,4 @@ def test_otda(): # test_sinkhorn_l1l2_transport_class() # test_sinkhorn_lpl1_transport_class() # test_mapping_transport_class() + -- cgit v1.2.3