summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-08-25 10:29:41 +0200
committerSlasnista <stan.chambon@gmail.com>2017-08-25 10:29:41 +0200
commit2d4d0b46f88c66ebc5502c840703ba6ce8910376 (patch)
tree5f0ead47c2d0bb769edca3e0f9e8f155ea73ffd0 /test/test_da.py
parent09302239b3e4e1a90c1a4e2d7a85b0af86b01365 (diff)
solving log issues to avoid errors and adding further tests
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py39
1 files changed, 33 insertions, 6 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 9578b3d..104a798 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -26,6 +26,8 @@ def test_sinkhorn_lpl1_transport_class():
# test its computed
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert hasattr(clf, "cost_")
+ assert hasattr(clf, "coupling_")
# test dimensions of coupling
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -89,6 +91,9 @@ def test_sinkhorn_l1l2_transport_class():
# test its computed
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert hasattr(clf, "cost_")
+ assert hasattr(clf, "coupling_")
+ assert hasattr(clf, "log_")
# test dimensions of coupling
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -137,6 +142,11 @@ def test_sinkhorn_l1l2_transport_class():
assert n_unsup != n_semisup, "semisupervised mode not working"
+ # check everything runs well with log=True
+ clf = ot.da.SinkhornL1l2Transport(log=True)
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert len(clf.log_.keys()) != 0
+
def test_sinkhorn_transport_class():
"""test_sinkhorn_transport
@@ -152,6 +162,9 @@ def test_sinkhorn_transport_class():
# test its computed
clf.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(clf, "cost_")
+ assert hasattr(clf, "coupling_")
+ assert hasattr(clf, "log_")
# test dimensions of coupling
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -200,6 +213,11 @@ def test_sinkhorn_transport_class():
assert n_unsup != n_semisup, "semisupervised mode not working"
+ # check everything runs well with log=True
+ clf = ot.da.SinkhornTransport(log=True)
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert len(clf.log_.keys()) != 0
+
def test_emd_transport_class():
"""test_sinkhorn_transport
@@ -215,6 +233,8 @@ def test_emd_transport_class():
# test its computed
clf.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(clf, "cost_")
+ assert hasattr(clf, "coupling_")
# test dimensions of coupling
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -282,6 +302,9 @@ def test_mapping_transport_class():
# check computation and dimensions if bias == False
clf = ot.da.MappingTransport(kernel="linear", bias=False)
clf.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(clf, "coupling_")
+ assert hasattr(clf, "mapping_")
+ assert hasattr(clf, "log_")
assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
assert_equal(clf.mapping_.shape, ((Xs.shape[1], Xt.shape[1])))
@@ -369,6 +392,11 @@ def test_mapping_transport_class():
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)
+ # check everything runs well with log=True
+ clf = ot.da.MappingTransport(kernel="gaussian", log=True)
+ clf.fit(Xs=Xs, Xt=Xt)
+ assert len(clf.log_.keys()) != 0
+
def test_otda():
@@ -434,9 +462,8 @@ def test_otda():
# if __name__ == "__main__":
- # test_otda()
- # test_sinkhorn_transport_class()
- # test_emd_transport_class()
- # test_sinkhorn_l1l2_transport_class()
- # test_sinkhorn_lpl1_transport_class()
- # test_mapping_transport_class()
+# test_sinkhorn_transport_class()
+# test_emd_transport_class()
+# test_sinkhorn_l1l2_transport_class()
+# test_sinkhorn_lpl1_transport_class()
+# test_mapping_transport_class()