summaryrefslogtreecommitdiff
path: root/ot/datasets.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-28 12:50:54 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-28 12:50:54 +0200
commit9523b1e95ed7c117554ff673532d150841092137 (patch)
tree1ed5b5060b3b59e79ca2e286058162cdd51c3b75 /ot/datasets.py
parentf33087d2b1790dac773782bb0d91bcfe7ce6a079 (diff)
doc datasets.py
Diffstat (limited to 'ot/datasets.py')
-rw-r--r--ot/datasets.py71
1 files changed, 60 insertions, 11 deletions
diff --git a/ot/datasets.py b/ot/datasets.py
index f22e345..6388d94 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -8,14 +8,50 @@ import scipy as sp
def get_1D_gauss(n,m,s):
- "return a 1D histogram for a gaussian distribution (n bins, mean m and std s) "
+ """return a 1D histogram for a gaussian distribution (n bins, mean m and std s)
+
+ Parameters
+ ----------
+
+ n : int
+ number of bins in the histogram
+ m : float
+ mean value of the gaussian distribution
+ s : float
+ standard deviaton of the gaussian distribution
+
+
+ Returns
+ -------
+ h : np.array (n,)
+ 1D histogram for a gaussian distribution
+
+ """
x=np.arange(n,dtype=np.float64)
h=np.exp(-(x-m)**2/(2*s^2))
return h/h.sum()
def get_2D_samples_gauss(n,m,sigma):
- "return samples from 2D gaussian (n samples, mean m and cov sigma) "
+ """return n samples drawn from 2D gaussian N(m,sigma)
+
+ Parameters
+ ----------
+
+ n : int
+ number of bins in the histogram
+ m : np.array (2,)
+ mean value of the gaussian distribution
+ sigma : np.array (2,2)
+ covariance matrix of the gaussian distribution
+
+
+ Returns
+ -------
+ X : np.array (n,2)
+ n samples drawn from N(m,sigma)
+
+ """
if np.isscalar(sigma):
sigma=np.array([sigma,])
if len(sigma)>1:
@@ -26,8 +62,26 @@ def get_2D_samples_gauss(n,m,sigma):
return res
def get_data_classif(dataset,n,nz=.5,**kwargs):
- """
- dataset generation
+ """ dataset generation for classification problems
+
+ Parameters
+ ----------
+
+ dataset : str
+ type of classification problem (see code)
+ n : int
+ number of training samples
+ nz : float
+ noise level (>0)
+
+
+ Returns
+ -------
+ X : np.array (n,d)
+ n observation of size d
+ y : np.array (n,)
+ labels of the samples
+
"""
if dataset.lower()=='3gauss':
y=np.floor((np.arange(n)*1.0/n*3))+1
@@ -50,15 +104,10 @@ def get_data_classif(dataset,n,nz=.5,**kwargs):
x[y==3,0]=2. ; x[y==3,1]=0
x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2)
- x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
-# elif dataset.lower()=='sinreg':
-#
-# x=np.random.rand(n,1)
-# y=4*x+np.sin(2*np.pi*x)+nz*np.random.randn(n,1)
-
+ x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
else:
x=0
y=0
print("unknown dataset")
- return x,y \ No newline at end of file
+ return x,y.astype(int) \ No newline at end of file