summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-09-05 10:09:55 +0200
committerSlasnista <stan.chambon@gmail.com>2017-09-05 10:09:55 +0200
commit2097116c7db725a88876d617e20a94f32627f7c9 (patch)
tree61bd6b7408dd714ee6643c033bba745af0197059 /test/test_da.py
parent49c100de34583329058b39d414d2aa49b7fd15bf (diff)
solving pb
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py61
1 files changed, 34 insertions, 27 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 3602db9..593dc53 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -36,8 +36,10 @@ def test_sinkhorn_lpl1_transport_class():
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
@@ -108,8 +110,10 @@ def test_sinkhorn_l1l2_transport_class():
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
@@ -187,8 +191,10 @@ def test_sinkhorn_transport_class():
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
@@ -263,8 +269,10 @@ def test_emd_transport_class():
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
@@ -342,8 +350,10 @@ def test_mapping_transport_class():
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
@@ -363,8 +373,10 @@ def test_mapping_transport_class():
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
@@ -389,8 +401,10 @@ def test_mapping_transport_class():
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
@@ -410,8 +424,10 @@ def test_mapping_transport_class():
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
@@ -454,7 +470,8 @@ def test_otda():
da_entrop.interp()
da_entrop.predict(xs)
- np.testing.assert_allclose(a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3)
+ np.testing.assert_allclose(
+ a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3)
np.testing.assert_allclose(b, np.sum(da_entrop.G, 0), rtol=1e-3, atol=1e-3)
# non-convex Group lasso regularization
@@ -488,13 +505,3 @@ def test_otda():
da_emd = ot.da.OTDA_mapping_kernel() # init class
da_emd.fit(xs, xt, numItermax=10) # fit distributions
da_emd.predict(xs) # interpolation of source samples
-
-
-# if __name__ == "__main__":
-
-# test_sinkhorn_transport_class()
-# test_emd_transport_class()
-# test_sinkhorn_l1l2_transport_class()
-# test_sinkhorn_lpl1_transport_class()
-# test_mapping_transport_class()
-