summary | refs | log | tree | commit | diff
path: root/ot/datasets.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/datasets.py')
-rw-r--r-- ot/datasets.py 23
1 files changed, 18 insertions, 5 deletions
diff --git a/ot/datasets.py b/ot/datasets.py
index ba0cfd9..b86ef3b 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -1,5 +1,5 @@
"""
-Simple example datasets for OT
+Simple example datasets
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
@@ -30,7 +30,7 @@ def make_1D_gauss(n, m, s):
1D histogram for a gaussian distribution
"""
x = np.arange(n, dtype=np.float64)
- h = np.exp(-(x - m)**2 / (2 * s**2))
+ h = np.exp(-(x - m) ** 2 / (2 * s ** 2))
return h / h.sum()
@@ -80,7 +80,7 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None):
return make_2D_samples_gauss(n, m, sigma, random_state=None)
-def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
+def make_data_classif(dataset, n, nz=.5, theta=0, p=.5, random_state=None, **kwargs):
"""Dataset generation for classification problems
Parameters
@@ -91,6 +91,8 @@ def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
number of training samples
nz : float
noise level (>0)
+ p : float
+ proportion of one class in the binary setting
random_state : int, RandomState instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
If RandomState instance, random_state is the random number generator;
@@ -145,11 +147,22 @@ def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
n2 = np.sum(y == 2)
x = np.zeros((n, 2))
- x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz, random_state=generator)
- x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz, random_state=generator)
+ x[y == 1, :] = make_2D_samples_gauss(n1, m1, nz, random_state=generator)
+ x[y == 2, :] = make_2D_samples_gauss(n2, m2, nz, random_state=generator)
x = x.dot(rot)
+ elif dataset.lower() == '2gauss_prop':
+
+ y = np.concatenate((np.ones(int(p * n)), np.zeros(int((1 - p) * n))))
+ x = np.hstack((0 * y[:, None] - 0, 1 - 2 * y[:, None])) + nz * np.random.randn(len(y), 2)
+
+ if ('bias' not in kwargs) and ('b' not in kwargs):
+ kwargs['bias'] = np.array([0, 2])
+
+ x[:, 0] += kwargs['bias'][0]
+ x[:, 1] += kwargs['bias'][1]
+
else:
x = np.array(0)
y = np.array(0)