summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-08-23 13:50:24 +0200
committerSlasnista <stan.chambon@gmail.com>2017-08-23 13:50:24 +0200
commit791a4a6f215033a75d5f56cd16fe2412301bec14 (patch)
tree9e1d54a85d6cae07ab745363114b8e662dc7dce9 /test/test_da.py
parent8149e059be7f715834d11b365855f2684bd3d6f5 (diff)
out of samples transform and inverse transform by batch
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py66
1 files changed, 33 insertions, 33 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 93f7e83..196f4c4 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -28,14 +28,14 @@ def test_sinkhorn_lpl1_transport_class():
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
# test dimensions of coupling
- assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(clf.Coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(clf.Coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(clf.Coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = clf.transform(Xs=Xs)
@@ -64,13 +64,13 @@ def test_sinkhorn_lpl1_transport_class():
# test semi supervised mode
clf = ot.da.SinkhornLpl1Transport()
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
- n_unsup = np.sum(clf.Cost)
+ n_unsup = np.sum(clf.cost_)
# test semi supervised mode
clf = ot.da.SinkhornLpl1Transport()
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
- assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(clf.Cost)
+ assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(clf.cost_)
assert n_unsup != n_semisup, "semisupervised mode not working"
@@ -91,14 +91,14 @@ def test_sinkhorn_l1l2_transport_class():
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
# test dimensions of coupling
- assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(clf.Coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(clf.Coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(clf.Coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = clf.transform(Xs=Xs)
@@ -127,13 +127,13 @@ def test_sinkhorn_l1l2_transport_class():
# test semi supervised mode
clf = ot.da.SinkhornL1l2Transport()
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
- n_unsup = np.sum(clf.Cost)
+ n_unsup = np.sum(clf.cost_)
# test semi supervised mode
clf = ot.da.SinkhornL1l2Transport()
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
- assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(clf.Cost)
+ assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(clf.cost_)
assert n_unsup != n_semisup, "semisupervised mode not working"
@@ -154,14 +154,14 @@ def test_sinkhorn_transport_class():
clf.fit(Xs=Xs, Xt=Xt)
# test dimensions of coupling
- assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(clf.Coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(clf.Coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(clf.Coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = clf.transform(Xs=Xs)
@@ -190,13 +190,13 @@ def test_sinkhorn_transport_class():
# test semi supervised mode
clf = ot.da.SinkhornTransport()
clf.fit(Xs=Xs, Xt=Xt)
- n_unsup = np.sum(clf.Cost)
+ n_unsup = np.sum(clf.cost_)
# test semi supervised mode
clf = ot.da.SinkhornTransport()
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
- assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(clf.Cost)
+ assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(clf.cost_)
assert n_unsup != n_semisup, "semisupervised mode not working"
@@ -217,14 +217,14 @@ def test_emd_transport_class():
clf.fit(Xs=Xs, Xt=Xt)
# test dimensions of coupling
- assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(clf.Coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(clf.Coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(clf.Coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = clf.transform(Xs=Xs)
@@ -253,13 +253,13 @@ def test_emd_transport_class():
# test semi supervised mode
clf = ot.da.EMDTransport()
clf.fit(Xs=Xs, Xt=Xt)
- n_unsup = np.sum(clf.Cost)
+ n_unsup = np.sum(clf.cost_)
# test semi supervised mode
clf = ot.da.EMDTransport()
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
- assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(clf.Cost)
+ assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(clf.cost_)
assert n_unsup != n_semisup, "semisupervised mode not working"
@@ -326,9 +326,9 @@ def test_otda():
da_emd.predict(xs) # interpolation of source samples
-if __name__ == "__main__":
+# if __name__ == "__main__":
- test_sinkhorn_transport_class()
- test_emd_transport_class()
- test_sinkhorn_l1l2_transport_class()
- test_sinkhorn_lpl1_transport_class()
+# test_sinkhorn_transport_class()
+# test_emd_transport_class()
+# test_sinkhorn_l1l2_transport_class()
+# test_sinkhorn_lpl1_transport_class()