summaryrefslogtreecommitdiff
path: root/ot/datasets.py
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati@inria.fr>2019-07-23 10:32:41 +0200
committerHicham Janati <hicham.janati@inria.fr>2019-07-23 10:32:41 +0200
commit92292231a4d9661c399dbfd97b22d6f7f890f698 (patch)
tree248945b95f655fd51997a895a61eeed24988aa72 /ot/datasets.py
parent50a5a4111ada5e8c208da1acf731608930d0a278 (diff)
parent0063cb87a10293a24ad1c9483be121745958c24a (diff)
Merge branch 'master' of https://github.com/rflamary/POT into fix_mismatch_error_94
Diffstat (limited to 'ot/datasets.py')
-rw-r--r--ot/datasets.py32
1 files changed, 12 insertions, 20 deletions
diff --git a/ot/datasets.py b/ot/datasets.py
index e76e75d..ba0cfd9 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -17,7 +17,6 @@ def make_1D_gauss(n, m, s):
Parameters
----------
-
n : int
number of bins in the histogram
m : float
@@ -25,12 +24,10 @@ def make_1D_gauss(n, m, s):
s : float
standard deviaton of the gaussian distribution
-
Returns
-------
- h : np.array (n,)
- 1D histogram for a gaussian distribution
-
+ h : ndarray (n,)
+ 1D histogram for a gaussian distribution
"""
x = np.arange(n, dtype=np.float64)
h = np.exp(-(x - m)**2 / (2 * s**2))
@@ -44,16 +41,15 @@ def get_1D_gauss(n, m, sigma):
def make_2D_samples_gauss(n, m, sigma, random_state=None):
- """return n samples drawn from 2D gaussian N(m,sigma)
+ """Return n samples drawn from 2D gaussian N(m,sigma)
Parameters
----------
-
n : int
number of samples to make
- m : np.array (2,)
+ m : ndarray, shape (2,)
mean value of the gaussian distribution
- sigma : np.array (2,2)
+ sigma : ndarray, shape (2, 2)
covariance matrix of the gaussian distribution
random_state : int, RandomState instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
@@ -63,9 +59,8 @@ def make_2D_samples_gauss(n, m, sigma, random_state=None):
Returns
-------
- X : np.array (n,2)
- n samples drawn from N(m,sigma)
-
+ X : ndarray, shape (n, 2)
+ n samples drawn from N(m, sigma).
"""
generator = check_random_state(random_state)
@@ -86,11 +81,10 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None):
def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
- """ dataset generation for classification problems
+ """Dataset generation for classification problems
Parameters
----------
-
dataset : str
type of classification problem (see code)
n : int
@@ -105,13 +99,11 @@ def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
Returns
-------
- X : np.array (n,d)
- n observation of size d
- y : np.array (n,)
- labels of the samples
-
+ X : ndarray, shape (n, d)
+ n observation of size d
+ y : ndarray, shape (n,)
+ labels of the samples.
"""
-
generator = check_random_state(random_state)
if dataset.lower() == '3gauss':