diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-26 11:48:13 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-26 11:48:13 +0200 |
commit | 2bc41ad8bb54c76bade6db2c0e04fa387ff29500 (patch) | |
tree | d4e8aac854e1a3fa3bbd21caed2e748f530ba873 /test | |
parent | 4a45135dfa3f1aeae8b3bdf0c42422f0f60426e8 (diff) |
rng gpu
Diffstat (limited to 'test')
-rw-r--r-- | test/test_gpu.py | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/test/test_gpu.py b/test/test_gpu.py index 7ae159b..98f59f7 100644 --- a/test/test_gpu.py +++ b/test/test_gpu.py @@ -14,7 +14,7 @@ except ImportError: @pytest.mark.skipif(nogpu, reason="No GPU available") def test_gpu_sinkhorn(): - np.random.seed(0) + rng = np.random.RandomState(0) def describe_res(r): print("min:{:.3E}, max::{:.3E}, mean::{:.3E}, std::{:.3E}".format( @@ -22,8 +22,8 @@ def test_gpu_sinkhorn(): for n_samples in [50, 100, 500, 1000]: print(n_samples) - a = np.random.rand(n_samples // 4, 100) - b = np.random.rand(n_samples, 100) + a = rng.rand(n_samples // 4, 100) + b = rng.rand(n_samples, 100) time1 = time.time() transport = ot.da.OTDA_sinkhorn() transport.fit(a, b) @@ -43,7 +43,8 @@ def test_gpu_sinkhorn(): @pytest.mark.skipif(nogpu, reason="No GPU available") def test_gpu_sinkhorn_lpl1(): - np.random.seed(0) + + rng = np.random.RandomState(0) def describe_res(r): print("min:{:.3E}, max:{:.3E}, mean:{:.3E}, std:{:.3E}" @@ -51,9 +52,9 @@ def test_gpu_sinkhorn_lpl1(): for n_samples in [50, 100, 500]: print(n_samples) - a = np.random.rand(n_samples // 4, 100) + a = rng.rand(n_samples // 4, 100) labels_a = np.random.randint(10, size=(n_samples // 4)) - b = np.random.rand(n_samples, 100) + b = rng.rand(n_samples, 100) time1 = time.time() transport = ot.da.OTDA_lpl1() transport.fit(a, labels_a, b) |