diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-05-30 08:59:44 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-05-30 08:59:44 +0200 |
commit | 06eabe7d6bedbbeedf8dfe55fd1f448806f5ef6b (patch) | |
tree | 78adb0f8db1f1f415697f0221a6c0e70befea42e /ot/datasets.py | |
parent | fde3d59ca4a8454cbb64f6b43ee4cd4e4030a1c0 (diff) |
pep8 + working tests
Diffstat (limited to 'ot/datasets.py')
-rw-r--r-- | ot/datasets.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/ot/datasets.py b/ot/datasets.py index 7d64b3b..79fc290 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -9,6 +9,7 @@ Simple example datasets for OT import numpy as np import scipy as sp +from .utils import check_random_state def get_1D_gauss(n, m, s): @@ -60,7 +61,7 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None): n samples drawn from N(m,sigma) """ - + generator = check_random_state(random_state) if np.isscalar(sigma): sigma = np.array([sigma, ]) @@ -98,9 +99,9 @@ def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): labels of the samples """ - + generator = check_random_state(random_state) - + if dataset.lower() == '3gauss': y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 x = np.zeros((n, 2)) @@ -140,8 +141,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): n2 = np.sum(y == 2) x = np.zeros((n, 2)) - x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz,random_state=generator) - x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz,random_state=generator) + x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz, random_state=generator) + x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz, random_state=generator) x = x.dot(rot) |