From 81118f22197cdf4553427038526c8f730be256d7 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Wed, 26 Jul 2017 11:55:57 +0200 Subject: test_ot random state --- test/test_ot.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) (limited to 'test') diff --git a/test/test_ot.py b/test/test_ot.py index 9c0acab..7fe665f 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -18,9 +18,9 @@ def test_doctest(): def test_emd_emd2(): # test emd and emd2 for simple identity n = 100 - np.random.seed(0) + rng = np.random.RandomState(0) - x = np.random.randn(n, 2) + x = rng.randn(n, 2) u = ot.utils.unif(n) M = ot.dist(x, x) @@ -28,22 +28,22 @@ def test_emd_emd2(): G = ot.emd(u, u, M) # check G is identity - assert np.allclose(G, np.eye(n) / n) + np.testing.assert_allclose(G, np.eye(n) / n) # check constratints - assert np.allclose(u, G.sum(1)) # cf convergence sinkhorn - assert np.allclose(u, G.sum(0)) # cf convergence sinkhorn + np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn + np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn w = ot.emd2(u, u, M) # check loss=0 - assert np.allclose(w, 0) + np.testing.assert_allclose(w, 0) def test_emd_empty(): # test emd and emd2 for simple identity n = 100 - np.random.seed(0) + rng = np.random.RandomState(0) - x = np.random.randn(n, 2) + x = rng.randn(n, 2) u = ot.utils.unif(n) M = ot.dist(x, x) @@ -51,14 +51,14 @@ def test_emd_empty(): G = ot.emd([], [], M) # check G is identity - assert np.allclose(G, np.eye(n) / n) + np.testing.assert_allclose(G, np.eye(n) / n) # check constratints - assert np.allclose(u, G.sum(1)) # cf convergence sinkhorn - assert np.allclose(u, G.sum(0)) # cf convergence sinkhorn + np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn + np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn w = ot.emd2([], [], M) # check loss=0 - assert np.allclose(w, 0) + np.testing.assert_allclose(w, 0) def test_emd2_multi(): @@ -66,7 +66,6 @@ def test_emd2_multi(): from ot.datasets import get_1D_gauss as gauss n = 1000 # nb bins - np.random.seed(0) # bin positions x = np.arange(n, dtype=np.float64) @@ -96,4 +95,4 @@ def test_emd2_multi(): emdn = ot.emd2(a, b, M) ot.toc('multi proc : {} s') - assert np.allclose(emd1, emdn) + np.testing.assert_allclose(emd1, emdn) -- cgit v1.2.3