summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-26 11:55:57 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-26 11:55:57 +0200
commit81118f22197cdf4553427038526c8f730be256d7 (patch)
tree895a08ed3b0c1fed480a732db95dcaa058317bbd /test/test_ot.py
parent0e06129613c9b905f28b4d302cddb872cef6c234 (diff)
test_ot random state
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py27
1 files changed, 13 insertions, 14 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 9c0acab..7fe665f 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -18,9 +18,9 @@ def test_doctest():
def test_emd_emd2():
# test emd and emd2 for simple identity
n = 100
- np.random.seed(0)
+ rng = np.random.RandomState(0)
- x = np.random.randn(n, 2)
+ x = rng.randn(n, 2)
u = ot.utils.unif(n)
M = ot.dist(x, x)
@@ -28,22 +28,22 @@ def test_emd_emd2():
G = ot.emd(u, u, M)
# check G is identity
- assert np.allclose(G, np.eye(n) / n)
+ np.testing.assert_allclose(G, np.eye(n) / n)
# check constratints
- assert np.allclose(u, G.sum(1)) # cf convergence sinkhorn
- assert np.allclose(u, G.sum(0)) # cf convergence sinkhorn
+ np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
+ np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
w = ot.emd2(u, u, M)
# check loss=0
- assert np.allclose(w, 0)
+ np.testing.assert_allclose(w, 0)
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100
- np.random.seed(0)
+ rng = np.random.RandomState(0)
- x = np.random.randn(n, 2)
+ x = rng.randn(n, 2)
u = ot.utils.unif(n)
M = ot.dist(x, x)
@@ -51,14 +51,14 @@ def test_emd_empty():
G = ot.emd([], [], M)
# check G is identity
- assert np.allclose(G, np.eye(n) / n)
+ np.testing.assert_allclose(G, np.eye(n) / n)
# check constratints
- assert np.allclose(u, G.sum(1)) # cf convergence sinkhorn
- assert np.allclose(u, G.sum(0)) # cf convergence sinkhorn
+ np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
+ np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
w = ot.emd2([], [], M)
# check loss=0
- assert np.allclose(w, 0)
+ np.testing.assert_allclose(w, 0)
def test_emd2_multi():
@@ -66,7 +66,6 @@ def test_emd2_multi():
from ot.datasets import get_1D_gauss as gauss
n = 1000 # nb bins
- np.random.seed(0)
# bin positions
x = np.arange(n, dtype=np.float64)
@@ -96,4 +95,4 @@ def test_emd2_multi():
emdn = ot.emd2(a, b, M)
ot.toc('multi proc : {} s')
- assert np.allclose(emd1, emdn)
+ np.testing.assert_allclose(emd1, emdn)