summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-26 11:48:13 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-26 11:48:13 +0200
commit2bc41ad8bb54c76bade6db2c0e04fa387ff29500 (patch)
treed4e8aac854e1a3fa3bbd21caed2e748f530ba873 /test
parent4a45135dfa3f1aeae8b3bdf0c42422f0f60426e8 (diff)
rng gpu
Diffstat (limited to 'test')
-rw-r--r--test/test_gpu.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/test/test_gpu.py b/test/test_gpu.py
index 7ae159b..98f59f7 100644
--- a/test/test_gpu.py
+++ b/test/test_gpu.py
@@ -14,7 +14,7 @@ except ImportError:
@pytest.mark.skipif(nogpu, reason="No GPU available")
def test_gpu_sinkhorn():
- np.random.seed(0)
+ rng = np.random.RandomState(0)
def describe_res(r):
print("min:{:.3E}, max::{:.3E}, mean::{:.3E}, std::{:.3E}".format(
@@ -22,8 +22,8 @@ def test_gpu_sinkhorn():
for n_samples in [50, 100, 500, 1000]:
print(n_samples)
- a = np.random.rand(n_samples // 4, 100)
- b = np.random.rand(n_samples, 100)
+ a = rng.rand(n_samples // 4, 100)
+ b = rng.rand(n_samples, 100)
time1 = time.time()
transport = ot.da.OTDA_sinkhorn()
transport.fit(a, b)
@@ -43,7 +43,8 @@ def test_gpu_sinkhorn():
@pytest.mark.skipif(nogpu, reason="No GPU available")
def test_gpu_sinkhorn_lpl1():
- np.random.seed(0)
+
+ rng = np.random.RandomState(0)
def describe_res(r):
print("min:{:.3E}, max:{:.3E}, mean:{:.3E}, std:{:.3E}"
@@ -51,9 +52,9 @@ def test_gpu_sinkhorn_lpl1():
for n_samples in [50, 100, 500]:
print(n_samples)
- a = np.random.rand(n_samples // 4, 100)
+ a = rng.rand(n_samples // 4, 100)
labels_a = np.random.randint(10, size=(n_samples // 4))
- b = np.random.rand(n_samples, 100)
+ b = rng.rand(n_samples, 100)
time1 = time.time()
transport = ot.da.OTDA_lpl1()
transport.fit(a, labels_a, b)