From fa989062c17f87bd96aa58ad764fd3791ea11e22 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 15:00:50 +0200 Subject: Reame +pep8 --- test/test_gromov.py | 53 +++++++++++++++++++++++++++-------------------------- test/test_optim.py | 9 +++++---- 2 files changed, 32 insertions(+), 30 deletions(-) (limited to 'test') diff --git a/test/test_gromov.py b/test/test_gromov.py index 43b63e1..cd180d4 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -145,7 +145,8 @@ def test_gromov_entropic_barycenter(): 'kl_loss', 2e-3, max_iter=100, tol=1e-3) np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) - + + def test_fgw(): n_samples = 50 # nb samples @@ -155,9 +156,9 @@ def test_fgw(): xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s) xt = xs[::-1].copy() - - ys = np.random.randn(xs.shape[0],2) - yt= ys[::-1].copy() + + ys = np.random.randn(xs.shape[0], 2) + yt = ys[::-1].copy() p = ot.unif(n_samples) q = ot.unif(n_samples) @@ -167,11 +168,11 @@ def test_fgw(): C1 /= C1.max() C2 /= C2.max() - - M=ot.dist(ys,yt) - M/=M.max() - G = ot.gromov.fused_gromov_wasserstein(M,C1, C2, p, q, 'square_loss',alpha=0.5) + M = ot.dist(ys, yt) + M /= M.max() + + G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5) # check constratints np.testing.assert_allclose( @@ -187,36 +188,36 @@ def test_fgw_barycenter(): Xs, ys = ot.datasets.make_data_classif('3gauss', ns) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt) - - ys = np.random.randn(Xs.shape[0],2) - yt= np.random.randn(Xt.shape[0],2) + + ys = np.random.randn(Xs.shape[0], 2) + yt = np.random.randn(Xt.shape[0], 2) C1 = ot.dist(Xs) C2 = ot.dist(Xt) n_samples = 3 - X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5, - fixed_structure=False,fixed_features=False, - p=ot.unif(n_samples),loss_fun='square_loss', - max_iter=100, tol=1e-3) + X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) xalea = np.random.randn(n_samples, 2) init_C = ot.dist(xalea, xalea) - - X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],ps=[ot.unif(ns), ot.unif(nt)],lambdas=[.5, .5],alpha=0.5, - fixed_structure=True,init_C=init_C,fixed_features=False, - p=ot.unif(n_samples),loss_fun='square_loss', - max_iter=100, tol=1e-3) + + X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5, + fixed_structure=True, init_C=init_C, fixed_features=False, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) - - init_X=np.random.randn(n_samples,ys.shape[1]) - X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5, - fixed_structure=False,fixed_features=True, init_X=init_X, - p=ot.unif(n_samples),loss_fun='square_loss', - max_iter=100, tol=1e-3) + init_X = np.random.randn(n_samples, ys.shape[1]) + + X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=init_X, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) diff --git a/test/test_optim.py b/test/test_optim.py index 1188ef6..e7ba32a 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -65,8 +65,9 @@ def test_generalized_conditional_gradient(): np.testing.assert_allclose(a, G.sum(1), atol=1e-05) np.testing.assert_allclose(b, G.sum(0), atol=1e-05) - + + def test_solve_1d_linesearch_quad_funct(): - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(1,-1,0),0.5) - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,5,0),0) - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,0.5,0),1) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(1, -1, 0), 0.5) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1, 5, 0), 0) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1, 0.5, 0), 1) -- cgit v1.2.3