diff options
Diffstat (limited to 'ot/datasets.py')
-rw-r--r-- | ot/datasets.py | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/ot/datasets.py b/ot/datasets.py index ba0cfd9..a1ca7b6 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -30,7 +30,7 @@ def make_1D_gauss(n, m, s): 1D histogram for a gaussian distribution """ x = np.arange(n, dtype=np.float64) - h = np.exp(-(x - m)**2 / (2 * s**2)) + h = np.exp(-(x - m) ** 2 / (2 * s ** 2)) return h / h.sum() @@ -80,7 +80,7 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None): return make_2D_samples_gauss(n, m, sigma, random_state=None) -def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): +def make_data_classif(dataset, n, nz=.5, theta=0, p=.5, random_state=None, **kwargs): """Dataset generation for classification problems Parameters @@ -91,6 +91,8 @@ def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): number of training samples nz : float noise level (>0) + p : float + proportion of one class in the binary setting random_state : int, RandomState instance or None, optional (default=None) If int, random_state is the seed used by the random number generator; If RandomState instance, random_state is the random number generator; @@ -150,6 +152,17 @@ def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): x = x.dot(rot) + elif dataset.lower() == '2gauss_prop': + + y = np.concatenate((np.ones(int(p * n)), np.zeros(int((1 - p) * n)))) + x = np.hstack((0 * y[:, None] - 0, 1 - 2 * y[:, None])) + nz * np.random.randn(len(y), 2) + + if ('bias' not in kwargs) and ('b' not in kwargs): + kwargs['bias'] = np.array([0, 2]) + + x[:, 0] += kwargs['bias'][0] + x[:, 1] += kwargs['bias'][1] + else: x = np.array(0) y = np.array(0) |