summary refs log tree commit diff
path: root/ot/datasets.py
diff options
context:
space:
mode:
author Rémi Flamary <remi.flamary@gmail.com> 2016-11-03 14:53:52 +0100
committer Rémi Flamary <remi.flamary@gmail.com> 2016-11-03 14:53:52 +0100
commit 566645ad184e1205f7f666ea2f19021254c33d74 (patch)
tree 1d740a6771ab515d0cbfe9f21fde801398eb19b6 /ot/datasets.py
parent 981351165dbab740145d109b00782f0c41f2244b (diff)
add mapping estimation (still debugging)
Diffstat (limited to 'ot/datasets.py')
-rw-r--r-- ot/datasets.py 62
1 files changed, 40 insertions, 22 deletions
diff --git a/ot/datasets.py b/ot/datasets.py
index 6388d94..588f501 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -8,8 +8,8 @@ import scipy as sp
def get_1D_gauss(n,m,s):
- """return a 1D histogram for a gaussian distribution (n bins, mean m and std s)
-
+ """return a 1D histogram for a gaussian distribution (n bins, mean m and std s)
+
Parameters
----------
@@ -20,21 +20,21 @@ def get_1D_gauss(n,m,s):
s : float
standard deviaton of the gaussian distribution
-
+
Returns
-------
h : np.array (n,)
- 1D histogram for a gaussian distribution
-
+ 1D histogram for a gaussian distribution
+
"""
x=np.arange(n,dtype=np.float64)
h=np.exp(-(x-m)**2/(2*s^2))
return h/h.sum()
-
-
+
+
def get_2D_samples_gauss(n,m,sigma):
- """return n samples drawn from 2D gaussian N(m,sigma)
-
+ """return n samples drawn from 2D gaussian N(m,sigma)
+
Parameters
----------
@@ -45,12 +45,12 @@ def get_2D_samples_gauss(n,m,sigma):
sigma : np.array (2,2)
covariance matrix of the gaussian distribution
-
+
Returns
-------
X : np.array (n,2)
- n samples drawn from N(m,sigma)
-
+ n samples drawn from N(m,sigma)
+
"""
if np.isscalar(sigma):
sigma=np.array([sigma,])
@@ -61,9 +61,10 @@ def get_2D_samples_gauss(n,m,sigma):
res= np.random.randn(n,2)*np.sqrt(sigma)+m
return res
-def get_data_classif(dataset,n,nz=.5,**kwargs):
+
+def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs):
""" dataset generation for classification problems
-
+
Parameters
----------
@@ -74,13 +75,13 @@ def get_data_classif(dataset,n,nz=.5,**kwargs):
nz : float
noise level (>0)
-
+
Returns
-------
X : np.array (n,d)
- n observation of size d
+ n observation of size d
y : np.array (n,)
- labels of the samples
+ labels of the samples
"""
if dataset.lower()=='3gauss':
@@ -90,10 +91,10 @@ def get_data_classif(dataset,n,nz=.5,**kwargs):
x[y==1,0]=-1.; x[y==1,1]=-1.
x[y==2,0]=-1.; x[y==2,1]=1.
x[y==3,0]=1. ; x[y==3,1]=0
-
+
x[y!=3,:]+=1.5*nz*np.random.randn(sum(y!=3),2)
x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
-
+
elif dataset.lower()=='3gauss2':
y=np.floor((np.arange(n)*1.0/n*3))+1
x=np.zeros((n,2))
@@ -102,12 +103,29 @@ def get_data_classif(dataset,n,nz=.5,**kwargs):
x[y==1,0]=-2.; x[y==1,1]=-2.
x[y==2,0]=-2.; x[y==2,1]=2.
x[y==3,0]=2. ; x[y==3,1]=0
-
+
x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2)
- x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
+ x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
+
+ elif dataset.lower()=='gaussrot' :
+ rot=np.array([[np.cos(theta),-np.sin(theta)],[np.sin(theta),np.cos(theta)]])
+ m1=np.array([-1,-1])
+ m2=np.array([1,1])
+ y=np.floor((np.arange(n)*1.0/n*2))+1
+ n1=np.sum(y==1)
+ n2=np.sum(y==2)
+ x=np.zeros((n,2))
+
+ x[y==1,:]=get_2D_samples_gauss(n1,m1,nz)
+ x[y==2,:]=get_2D_samples_gauss(n2,m2,nz)
+
+ x=x.dot(rot)
+
+
+
else:
x=0
y=0
print("unknown dataset")
-
+
return x,y.astype(int) \ No newline at end of file