summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-09-05 10:00:01 +0200
committerSlasnista <stan.chambon@gmail.com>2017-09-05 10:00:01 +0200
commit49c100de34583329058b39d414d2aa49b7fd15bf (patch)
treed435b7814c86d710610ca291580ed3d3206edfce /test/test_da.py
parent8e4a7930cf1ff80edeb30021acaf7337a02d18a5 (diff)
test semi supervised mode ok written for all class | need different tolerance for EMDTransport
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py29
1 files changed, 18 insertions, 11 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 9fc42a3..3602db9 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -64,11 +64,11 @@ def test_sinkhorn_lpl1_transport_class():
assert_equal(transp_Xs.shape, Xs.shape)
# test unsupervised vs semi-supervised mode
- otda_unsup = ot.da.SinkhornTransport()
- otda_unsup.fit(Xs=Xs, Xt=Xt)
+ otda_unsup = ot.da.SinkhornLpl1Transport()
+ otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
n_unsup = np.sum(otda_unsup.cost_)
- otda_semi = ot.da.SinkhornTransport()
+ otda_semi = ot.da.SinkhornLpl1Transport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
n_semisup = np.sum(otda_semi.cost_)
@@ -136,11 +136,11 @@ def test_sinkhorn_l1l2_transport_class():
assert_equal(transp_Xs.shape, Xs.shape)
# test unsupervised vs semi-supervised mode
- otda_unsup = ot.da.SinkhornTransport()
- otda_unsup.fit(Xs=Xs, Xt=Xt)
+ otda_unsup = ot.da.SinkhornL1l2Transport()
+ otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
n_unsup = np.sum(otda_unsup.cost_)
- otda_semi = ot.da.SinkhornTransport()
+ otda_semi = ot.da.SinkhornL1l2Transport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
n_semisup = np.sum(otda_semi.cost_)
@@ -152,7 +152,9 @@ def test_sinkhorn_l1l2_transport_class():
# and labeled target samples
mass_semi = np.sum(
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
- assert mass_semi == 0, "semisupervised mode not working"
+ mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
+ assert_allclose(mass_semi, np.zeros_like(mass_semi),
+ rtol=1e-9, atol=1e-9)
# check everything runs well with log=True
otda = ot.da.SinkhornL1l2Transport(log=True)
@@ -289,11 +291,11 @@ def test_emd_transport_class():
assert_equal(transp_Xs.shape, Xs.shape)
# test unsupervised vs semi-supervised mode
- otda_unsup = ot.da.SinkhornTransport()
- otda_unsup.fit(Xs=Xs, Xt=Xt)
+ otda_unsup = ot.da.EMDTransport()
+ otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
n_unsup = np.sum(otda_unsup.cost_)
- otda_semi = ot.da.SinkhornTransport()
+ otda_semi = ot.da.EMDTransport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
n_semisup = np.sum(otda_semi.cost_)
@@ -305,7 +307,11 @@ def test_emd_transport_class():
# and labeled target samples
mass_semi = np.sum(
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
- assert mass_semi == 0, "semisupervised mode not working"
+ mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
+
+ # we need to use a small tolerance here, otherwise the test breaks
+ assert_allclose(mass_semi, np.zeros_like(mass_semi),
+ rtol=1e-2, atol=1e-2)
def test_mapping_transport_class():
@@ -491,3 +497,4 @@ def test_otda():
# test_sinkhorn_l1l2_transport_class()
# test_sinkhorn_lpl1_transport_class()
# test_mapping_transport_class()
+