diff options
author | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-01 09:42:09 +0200 |
---|---|---|
committer | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-01 09:42:09 +0200 |
commit | b1f87363b160735b6e2df59380f9de56b7934b53 (patch) | |
tree | 7e36f3cc31dad1ffeefc53c0f41d4802dd8ec54c /ot | |
parent | 547a03ef87e4aa92edc1e89ee2db04114e1a8ad5 (diff) |
add dataset clean plot
Diffstat (limited to 'ot')
-rw-r--r-- | ot/datasets.py | 15 |
1 file changed, 14 insertions, 1 deletion
diff --git a/ot/datasets.py b/ot/datasets.py index ba0cfd9..eea9f37 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -80,7 +80,7 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None): return make_2D_samples_gauss(n, m, sigma, random_state=None) -def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): +def make_data_classif(dataset, n, nz=.5, theta=0, p = .5, random_state=None, **kwargs): """Dataset generation for classification problems Parameters @@ -91,6 +91,8 @@ def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): number of training samples nz : float noise level (>0) + p : float + proportion of one class in the binary setting random_state : int, RandomState instance or None, optional (default=None) If int, random_state is the seed used by the random number generator; If RandomState instance, random_state is the random number generator; @@ -150,6 +152,17 @@ def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): x = x.dot(rot) + elif dataset.lower() == '2gauss_prop': + + y = np.concatenate((np.ones(int(p * n)), np.zeros(int((1 - p) * n)))) + x = np.hstack((0 * y[:, None] - 0, 1 - 2 * y[:, None])) + nz * np.random.randn(len(y), 2) + + if ('bias' not in kwargs) and ('b' not in kwargs): + kwargs['bias'] = np.array([0, 2]) + + x[:, 0] += kwargs['bias'][0] + x[:, 1] += kwargs['bias'][1] + else: x = np.array(0) y = np.array(0) |