diff options
Diffstat (limited to 'ot/datasets.py')
-rw-r--r-- | ot/datasets.py | 60 |
1 files changed, 46 insertions, 14 deletions
diff --git a/ot/datasets.py b/ot/datasets.py index e4fe118..362a89b 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -9,9 +9,10 @@ Simple example datasets for OT import numpy as np import scipy as sp +from .utils import check_random_state, deprecated -def get_1D_gauss(n, m, s): +def make_1D_gauss(n, m, s): """return a 1D histogram for a gaussian distribution (n bins, mean m and std s) Parameters @@ -36,19 +37,29 @@ def get_1D_gauss(n, m, s): return h / h.sum() -def get_2D_samples_gauss(n, m, sigma): +@deprecated() +def get_1D_gauss(n, m, sigma, random_state=None): + """ Deprecated, use make_1D_gauss instead (random_state is ignored) """ + return make_1D_gauss(n, m, sigma) + + +def make_2D_samples_gauss(n, m, sigma, random_state=None): """return n samples drawn from 2D gaussian N(m,sigma) Parameters ---------- n : int - number of bins in the histogram + number of samples to make m : np.array (2,) mean value of the gaussian distribution sigma : np.array (2,2) covariance matrix of the gaussian distribution - + random_state : int, RandomState instance or None, optional (default=None) + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. 
Returns ------- @@ -56,17 +67,25 @@ def get_2D_samples_gauss(n, m, sigma): n samples drawn from N(m,sigma) """ + + generator = check_random_state(random_state) if np.isscalar(sigma): sigma = np.array([sigma, ]) if len(sigma) > 1: P = sp.linalg.sqrtm(sigma) - res = np.random.randn(n, 2).dot(P) + m + res = generator.randn(n, 2).dot(P) + m else: - res = np.random.randn(n, 2) * np.sqrt(sigma) + m + res = generator.randn(n, 2) * np.sqrt(sigma) + m return res -def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): +@deprecated() +def get_2D_samples_gauss(n, m, sigma, random_state=None): + """ Deprecated, use make_2D_samples_gauss instead """ + return make_2D_samples_gauss(n, m, sigma, random_state=random_state) + + +def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): """ dataset generation for classification problems Parameters @@ -78,7 +97,11 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): number of training samples nz : float noise level (>0) - + random_state : int, RandomState instance or None, optional (default=None) + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. Returns ------- @@ -88,6 +111,9 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): labels of the samples """ + + generator = check_random_state(random_state) + if dataset.lower() == '3gauss': y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 x = np.zeros((n, 2)) @@ -99,8 +125,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): x[y == 3, 0] = 1. 
x[y == 3, 1] = 0 - x[y != 3, :] += 1.5 * nz * np.random.randn(sum(y != 3), 2) - x[y == 3, :] += 2 * nz * np.random.randn(sum(y == 3), 2) + x[y != 3, :] += 1.5 * nz * generator.randn(sum(y != 3), 2) + x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2) elif dataset.lower() == '3gauss2': y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 @@ -114,8 +140,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): x[y == 3, 0] = 2. x[y == 3, 1] = 0 - x[y != 3, :] += nz * np.random.randn(sum(y != 3), 2) - x[y == 3, :] += 2 * nz * np.random.randn(sum(y == 3), 2) + x[y != 3, :] += nz * generator.randn(sum(y != 3), 2) + x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2) elif dataset.lower() == 'gaussrot': rot = np.array( @@ -127,8 +153,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): n2 = np.sum(y == 2) x = np.zeros((n, 2)) - x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz) - x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz) + x[y == 1, :] = make_2D_samples_gauss(n1, m1, nz, random_state=generator) + x[y == 2, :] = make_2D_samples_gauss(n2, m2, nz, random_state=generator) x = x.dot(rot) @@ -138,3 +164,9 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): print("unknown dataset") return x, y.astype(int) + + +@deprecated() +def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): + """ Deprecated, use make_data_classif instead """ + return make_data_classif(dataset, n, nz=nz, theta=theta, random_state=random_state, **kwargs) |