summaryrefslogtreecommitdiff
path: root/ot/datasets.py
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-01 09:42:09 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-01 09:42:09 +0200
commitb1f87363b160735b6e2df59380f9de56b7934b53 (patch)
tree7e36f3cc31dad1ffeefc53c0f41d4802dd8ec54c /ot/datasets.py
parent547a03ef87e4aa92edc1e89ee2db04114e1a8ad5 (diff)
add dataset clean plot
Diffstat (limited to 'ot/datasets.py')
-rw-r--r--ot/datasets.py15
1 files changed, 14 insertions, 1 deletions
diff --git a/ot/datasets.py b/ot/datasets.py
index ba0cfd9..eea9f37 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -80,7 +80,7 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None):
return make_2D_samples_gauss(n, m, sigma, random_state=None)
-def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
+def make_data_classif(dataset, n, nz=.5, theta=0, p = .5, random_state=None, **kwargs):
"""Dataset generation for classification problems
Parameters
@@ -91,6 +91,8 @@ def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
number of training samples
nz : float
noise level (>0)
+ p : float
+ proportion of one class in the binary setting
random_state : int, RandomState instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
If RandomState instance, random_state is the random number generator;
@@ -150,6 +152,17 @@ def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
x = x.dot(rot)
+ elif dataset.lower() == '2gauss_prop':
+
+ y = np.concatenate((np.ones(int(p * n)), np.zeros(int((1 - p) * n))))
+ x = np.hstack((0 * y[:, None] - 0, 1 - 2 * y[:, None])) + nz * np.random.randn(len(y), 2)
+
+ if ('bias' not in kwargs) and ('b' not in kwargs):
+ kwargs['bias'] = np.array([0, 2])
+
+ x[:, 0] += kwargs['bias'][0]
+ x[:, 1] += kwargs['bias'][1]
+
else:
x = np.array(0)
y = np.array(0)