From fde3d59ca4a8454cbb64f6b43ee4cd4e4030a1c0 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Wed, 30 May 2018 08:57:46 +0200 Subject: add random_state --- ot/datasets.py | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) (limited to 'ot/datasets.py') diff --git a/ot/datasets.py b/ot/datasets.py index e4fe118..7d64b3b 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -36,7 +36,7 @@ def get_1D_gauss(n, m, s): return h / h.sum() -def get_2D_samples_gauss(n, m, sigma): +def get_2D_samples_gauss(n, m, sigma, random_state=None): """return n samples drawn from 2D gaussian N(m,sigma) Parameters @@ -48,7 +48,11 @@ def get_2D_samples_gauss(n, m, sigma): mean value of the gaussian distribution sigma : np.array (2,2) covariance matrix of the gaussian distribution - + random_state : int, RandomState instance or None, optional (default=None) + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. 
Returns ------- @@ -56,17 +60,19 @@ def get_2D_samples_gauss(n, m, sigma): n samples drawn from N(m,sigma) """ + + generator = check_random_state(random_state) if np.isscalar(sigma): sigma = np.array([sigma, ]) if len(sigma) > 1: P = sp.linalg.sqrtm(sigma) - res = np.random.randn(n, 2).dot(P) + m + res = generator.randn(n, 2).dot(P) + m else: - res = np.random.randn(n, 2) * np.sqrt(sigma) + m + res = generator.randn(n, 2) * np.sqrt(sigma) + m return res -def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): +def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): """ dataset generation for classification problems Parameters @@ -78,7 +84,11 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): number of training samples nz : float noise level (>0) - + random_state : int, RandomState instance or None, optional (default=None) + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. Returns ------- @@ -88,6 +98,9 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): labels of the samples """ + + generator = check_random_state(random_state) + if dataset.lower() == '3gauss': y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 x = np.zeros((n, 2)) @@ -99,8 +112,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): x[y == 3, 0] = 1. x[y == 3, 1] = 0 - x[y != 3, :] += 1.5 * nz * np.random.randn(sum(y != 3), 2) - x[y == 3, :] += 2 * nz * np.random.randn(sum(y == 3), 2) + x[y != 3, :] += 1.5 * nz * generator.randn(sum(y != 3), 2) + x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2) elif dataset.lower() == '3gauss2': y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 @@ -114,8 +127,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): x[y == 3, 0] = 2. 
x[y == 3, 1] = 0 - x[y != 3, :] += nz * np.random.randn(sum(y != 3), 2) - x[y == 3, :] += 2 * nz * np.random.randn(sum(y == 3), 2) + x[y != 3, :] += nz * generator.randn(sum(y != 3), 2) + x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2) elif dataset.lower() == 'gaussrot': rot = np.array( @@ -127,8 +140,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): n2 = np.sum(y == 2) x = np.zeros((n, 2)) - x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz) - x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz) + x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz,random_state=generator) + x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz,random_state=generator) x = x.dot(rot) -- cgit v1.2.3