summaryrefslogtreecommitdiff
path: root/ot/datasets.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-05-30 08:59:44 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-05-30 08:59:44 +0200
commit06eabe7d6bedbbeedf8dfe55fd1f448806f5ef6b (patch)
tree78adb0f8db1f1f415697f0221a6c0e70befea42e /ot/datasets.py
parentfde3d59ca4a8454cbb64f6b43ee4cd4e4030a1c0 (diff)
pep8 + working tests
Diffstat (limited to 'ot/datasets.py')
-rw-r--r--ot/datasets.py11
1 files changed, 6 insertions, 5 deletions
diff --git a/ot/datasets.py b/ot/datasets.py
index 7d64b3b..79fc290 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -9,6 +9,7 @@ Simple example datasets for OT
import numpy as np
import scipy as sp
+from .utils import check_random_state
def get_1D_gauss(n, m, s):
@@ -60,7 +61,7 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None):
n samples drawn from N(m,sigma)
"""
-
+
generator = check_random_state(random_state)
if np.isscalar(sigma):
sigma = np.array([sigma, ])
@@ -98,9 +99,9 @@ def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
labels of the samples
"""
-
+
generator = check_random_state(random_state)
-
+
if dataset.lower() == '3gauss':
y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
x = np.zeros((n, 2))
@@ -140,8 +141,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
n2 = np.sum(y == 2)
x = np.zeros((n, 2))
- x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz,random_state=generator)
- x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz,random_state=generator)
+ x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz, random_state=generator)
+ x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz, random_state=generator)
x = x.dot(rot)