From 2097116c7db725a88876d617e20a94f32627f7c9 Mon Sep 17 00:00:00 2001 From: Slasnista Date: Tue, 5 Sep 2017 10:09:55 +0200 Subject: solving pb --- test/test_da.py | 61 ++++++++++++++++++++++++++++++++------------------------- 1 file changed, 34 insertions(+), 27 deletions(-) (limited to 'test') diff --git a/test/test_da.py b/test/test_da.py index 3602db9..593dc53 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -36,8 +36,10 @@ def test_sinkhorn_lpl1_transport_class(): # test margin constraints mu_s = unif(ns) mu_t = unif(nt) - assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) @@ -108,8 +110,10 @@ def test_sinkhorn_l1l2_transport_class(): # test margin constraints mu_s = unif(ns) mu_t = unif(nt) - assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) @@ -187,8 +191,10 @@ def test_sinkhorn_transport_class(): # test margin constraints mu_s = unif(ns) mu_t = unif(nt) - assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) @@ -263,8 +269,10 @@ def test_emd_transport_class(): # test margin constraints mu_s = unif(ns) mu_t = unif(nt) - assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) @@ -342,8 +350,10 @@ def test_mapping_transport_class(): # test margin constraints mu_s = unif(ns) mu_t = unif(nt) - assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) @@ -363,8 +373,10 @@ def test_mapping_transport_class(): # test margin constraints mu_s = unif(ns) mu_t = unif(nt) - assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) @@ -389,8 +401,10 @@ def test_mapping_transport_class(): # test margin constraints mu_s = unif(ns) mu_t = unif(nt) - assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) @@ -410,8 +424,10 @@ def test_mapping_transport_class(): # test margin constraints mu_s = unif(ns) mu_t = unif(nt) - assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) @@ -454,7 +470,8 @@ def test_otda(): da_entrop.interp() da_entrop.predict(xs) - np.testing.assert_allclose(a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3) + np.testing.assert_allclose( + a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3) np.testing.assert_allclose(b, np.sum(da_entrop.G, 0), rtol=1e-3, atol=1e-3) # non-convex Group lasso regularization @@ -488,13 +505,3 @@ def test_otda(): da_emd = ot.da.OTDA_mapping_kernel() # init class da_emd.fit(xs, xt, numItermax=10) # fit distributions da_emd.predict(xs) # interpolation of source samples - - -# if __name__ == "__main__": - -# test_sinkhorn_transport_class() -# test_emd_transport_class() -# test_sinkhorn_l1l2_transport_class() -# test_sinkhorn_lpl1_transport_class() -# test_mapping_transport_class() - -- cgit v1.2.3