From 90e42f32bdf0dd06667edaf172c51f4d4fce2c8b Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Wed, 30 May 2018 09:30:21 +0200 Subject: replace function name tin tests --- test/test_gromov.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'test/test_gromov.py') diff --git a/test/test_gromov.py b/test/test_gromov.py index bb23469..fb86274 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -15,7 +15,7 @@ def test_gromov(): mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s) xt = xs[::-1].copy() @@ -55,7 +55,7 @@ def test_entropic_gromov(): mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s) xt = xs[::-1].copy() @@ -96,8 +96,8 @@ def test_gromov_barycenter(): ns = 50 nt = 60 - Xs, ys = ot.datasets.get_data_classif('3gauss', ns) - Xt, yt = ot.datasets.get_data_classif('3gauss2', nt) + Xs, ys = ot.datasets.make_data_classif('3gauss', ns) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt) C1 = ot.dist(Xs) C2 = ot.dist(Xt) @@ -123,8 +123,8 @@ def test_gromov_entropic_barycenter(): ns = 50 nt = 60 - Xs, ys = ot.datasets.get_data_classif('3gauss', ns) - Xt, yt = ot.datasets.get_data_classif('3gauss2', nt) + Xs, ys = ot.datasets.make_data_classif('3gauss', ns) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt) C1 = ot.dist(Xs) C2 = ot.dist(Xt) @@ -133,13 +133,13 @@ def test_gromov_entropic_barycenter(): Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2], [ot.unif(ns), ot.unif(nt) ], ot.unif(n_samples), [.5, .5], - 'square_loss', 1e-3, + 'square_loss', 2e-3, max_iter=100, tol=1e-3) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2], [ot.unif(ns), ot.unif(nt) ], ot.unif(n_samples), [.5, .5], - 'kl_loss', 1e-3, + 'kl_loss', 2e-3, max_iter=100, tol=1e-3) np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) -- cgit v1.2.3