summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-08-25 10:29:41 +0200
committerNicolas Courty <Nico@MacBook-Pro-de-Nicolas.local>2017-09-01 11:09:13 +0200
commit6167f34a721886d4b9038a8b1746a2c8c81132ce (patch)
tree5f0ead47c2d0bb769edca3e0f9e8f155ea73ffd0 /test
parentfc58f39fc730a9e1bb2215ef063e37c50f0ebc1f (diff)
solving log issues to avoid errors and adding further tests
Diffstat (limited to 'test')
-rw-r--r--test/test_da.py39
1 files changed, 33 insertions, 6 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 9578b3d..104a798 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -26,6 +26,8 @@ def test_sinkhorn_lpl1_transport_class():
# test its computed
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert hasattr(clf, "cost_")
+ assert hasattr(clf, "coupling_")
# test dimensions of coupling
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -89,6 +91,9 @@ def test_sinkhorn_l1l2_transport_class():
# test its computed
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert hasattr(clf, "cost_")
+ assert hasattr(clf, "coupling_")
+ assert hasattr(clf, "log_")
# test dimensions of coupling
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -137,6 +142,11 @@ def test_sinkhorn_l1l2_transport_class():
assert n_unsup != n_semisup, "semisupervised mode not working"
+ # check everything runs well with log=True
+ clf = ot.da.SinkhornL1l2Transport(log=True)
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert len(clf.log_.keys()) != 0
+
def test_sinkhorn_transport_class():
"""test_sinkhorn_transport
@@ -152,6 +162,9 @@ def test_sinkhorn_transport_class():
# test its computed
clf.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(clf, "cost_")
+ assert hasattr(clf, "coupling_")
+ assert hasattr(clf, "log_")
# test dimensions of coupling
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -200,6 +213,11 @@ def test_sinkhorn_transport_class():
assert n_unsup != n_semisup, "semisupervised mode not working"
+ # check everything runs well with log=True
+ clf = ot.da.SinkhornTransport(log=True)
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert len(clf.log_.keys()) != 0
+
def test_emd_transport_class():
"""test_sinkhorn_transport
@@ -215,6 +233,8 @@ def test_emd_transport_class():
# test its computed
clf.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(clf, "cost_")
+ assert hasattr(clf, "coupling_")
# test dimensions of coupling
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -282,6 +302,9 @@ def test_mapping_transport_class():
# check computation and dimensions if bias == False
clf = ot.da.MappingTransport(kernel="linear", bias=False)
clf.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(clf, "coupling_")
+ assert hasattr(clf, "mapping_")
+ assert hasattr(clf, "log_")
assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
assert_equal(clf.mapping_.shape, ((Xs.shape[1], Xt.shape[1])))
@@ -369,6 +392,11 @@ def test_mapping_transport_class():
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)
+ # check everything runs well with log=True
+ clf = ot.da.MappingTransport(kernel="gaussian", log=True)
+ clf.fit(Xs=Xs, Xt=Xt)
+ assert len(clf.log_.keys()) != 0
+
def test_otda():
@@ -434,9 +462,8 @@ def test_otda():
# if __name__ == "__main__":
- # test_otda()
- # test_sinkhorn_transport_class()
- # test_emd_transport_class()
- # test_sinkhorn_l1l2_transport_class()
- # test_sinkhorn_lpl1_transport_class()
- # test_mapping_transport_class()
+# test_sinkhorn_transport_class()
+# test_emd_transport_class()
+# test_sinkhorn_l1l2_transport_class()
+# test_sinkhorn_lpl1_transport_class()
+# test_mapping_transport_class()