summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-29 15:00:50 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-29 15:00:50 +0200
commitfa989062c17f87bd96aa58ad764fd3791ea11e22 (patch)
tree0a6c7e571967c17aafb144ba018e063a2e43d070 /test
parent63bbeb34e48f02c97a762dab5232158d90a5cffc (diff)
Reame +pep8
Diffstat (limited to 'test')
-rw-r--r--test/test_gromov.py53
-rw-r--r--test/test_optim.py9
2 files changed, 32 insertions, 30 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 43b63e1..cd180d4 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -145,7 +145,8 @@ def test_gromov_entropic_barycenter():
'kl_loss', 2e-3,
max_iter=100, tol=1e-3)
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
-
+
+
def test_fgw():
n_samples = 50 # nb samples
@@ -155,9 +156,9 @@ def test_fgw():
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
xt = xs[::-1].copy()
-
- ys = np.random.randn(xs.shape[0],2)
- yt= ys[::-1].copy()
+
+ ys = np.random.randn(xs.shape[0], 2)
+ yt = ys[::-1].copy()
p = ot.unif(n_samples)
q = ot.unif(n_samples)
@@ -167,11 +168,11 @@ def test_fgw():
C1 /= C1.max()
C2 /= C2.max()
-
- M=ot.dist(ys,yt)
- M/=M.max()
- G = ot.gromov.fused_gromov_wasserstein(M,C1, C2, p, q, 'square_loss',alpha=0.5)
+ M = ot.dist(ys, yt)
+ M /= M.max()
+
+ G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5)
# check constratints
np.testing.assert_allclose(
@@ -187,36 +188,36 @@ def test_fgw_barycenter():
Xs, ys = ot.datasets.make_data_classif('3gauss', ns)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt)
-
- ys = np.random.randn(Xs.shape[0],2)
- yt= np.random.randn(Xt.shape[0],2)
+
+ ys = np.random.randn(Xs.shape[0], 2)
+ yt = np.random.randn(Xt.shape[0], 2)
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
n_samples = 3
- X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5,
- fixed_structure=False,fixed_features=False,
- p=ot.unif(n_samples),loss_fun='square_loss',
- max_iter=100, tol=1e-3)
+ X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=False,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
xalea = np.random.randn(n_samples, 2)
init_C = ot.dist(xalea, xalea)
-
- X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],ps=[ot.unif(ns), ot.unif(nt)],lambdas=[.5, .5],alpha=0.5,
- fixed_structure=True,init_C=init_C,fixed_features=False,
- p=ot.unif(n_samples),loss_fun='square_loss',
- max_iter=100, tol=1e-3)
+
+ X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5,
+ fixed_structure=True, init_C=init_C, fixed_features=False,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
-
- init_X=np.random.randn(n_samples,ys.shape[1])
- X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5,
- fixed_structure=False,fixed_features=True, init_X=init_X,
- p=ot.unif(n_samples),loss_fun='square_loss',
- max_iter=100, tol=1e-3)
+ init_X = np.random.randn(n_samples, ys.shape[1])
+
+ X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=True, init_X=init_X,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
diff --git a/test/test_optim.py b/test/test_optim.py
index 1188ef6..e7ba32a 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -65,8 +65,9 @@ def test_generalized_conditional_gradient():
np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
-
+
+
def test_solve_1d_linesearch_quad_funct():
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(1,-1,0),0.5)
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,5,0),0)
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,0.5,0),1)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(1, -1, 0), 0.5)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1, 5, 0), 0)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1, 0.5, 0), 1)