diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-03 14:53:52 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-03 14:53:52 +0100 |
commit | 566645ad184e1205f7f666ea2f19021254c33d74 (patch) | |
tree | 1d740a6771ab515d0cbfe9f21fde801398eb19b6 /ot/datasets.py | |
parent | 981351165dbab740145d109b00782f0c41f2244b (diff) | |
add mapping estimation (still debugging)
Diffstat (limited to 'ot/datasets.py')
-rw-r--r-- | ot/datasets.py | 62 |
1 files changed, 40 insertions, 22 deletions
diff --git a/ot/datasets.py b/ot/datasets.py index 6388d94..588f501 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -8,8 +8,8 @@ import scipy as sp def get_1D_gauss(n,m,s): - """return a 1D histogram for a gaussian distribution (n bins, mean m and std s) - + """return a 1D histogram for a gaussian distribution (n bins, mean m and std s) + Parameters ---------- @@ -20,21 +20,21 @@ def get_1D_gauss(n,m,s): s : float standard deviaton of the gaussian distribution - + Returns ------- h : np.array (n,) - 1D histogram for a gaussian distribution - + 1D histogram for a gaussian distribution + """ x=np.arange(n,dtype=np.float64) h=np.exp(-(x-m)**2/(2*s^2)) return h/h.sum() - - + + def get_2D_samples_gauss(n,m,sigma): - """return n samples drawn from 2D gaussian N(m,sigma) - + """return n samples drawn from 2D gaussian N(m,sigma) + Parameters ---------- @@ -45,12 +45,12 @@ def get_2D_samples_gauss(n,m,sigma): sigma : np.array (2,2) covariance matrix of the gaussian distribution - + Returns ------- X : np.array (n,2) - n samples drawn from N(m,sigma) - + n samples drawn from N(m,sigma) + """ if np.isscalar(sigma): sigma=np.array([sigma,]) @@ -61,9 +61,10 @@ def get_2D_samples_gauss(n,m,sigma): res= np.random.randn(n,2)*np.sqrt(sigma)+m return res -def get_data_classif(dataset,n,nz=.5,**kwargs): + +def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs): """ dataset generation for classification problems - + Parameters ---------- @@ -74,13 +75,13 @@ def get_data_classif(dataset,n,nz=.5,**kwargs): nz : float noise level (>0) - + Returns ------- X : np.array (n,d) - n observation of size d + n observation of size d y : np.array (n,) - labels of the samples + labels of the samples """ if dataset.lower()=='3gauss': @@ -90,10 +91,10 @@ def get_data_classif(dataset,n,nz=.5,**kwargs): x[y==1,0]=-1.; x[y==1,1]=-1. x[y==2,0]=-1.; x[y==2,1]=1. x[y==3,0]=1. 
; x[y==3,1]=0 - + x[y!=3,:]+=1.5*nz*np.random.randn(sum(y!=3),2) x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) - + elif dataset.lower()=='3gauss2': y=np.floor((np.arange(n)*1.0/n*3))+1 x=np.zeros((n,2)) @@ -102,12 +103,29 @@ def get_data_classif(dataset,n,nz=.5,**kwargs): x[y==1,0]=-2.; x[y==1,1]=-2. x[y==2,0]=-2.; x[y==2,1]=2. x[y==3,0]=2. ; x[y==3,1]=0 - + x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2) - x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) + x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) + + elif dataset.lower()=='gaussrot' : + rot=np.array([[np.cos(theta),-np.sin(theta)],[np.sin(theta),np.cos(theta)]]) + m1=np.array([-1,-1]) + m2=np.array([1,1]) + y=np.floor((np.arange(n)*1.0/n*2))+1 + n1=np.sum(y==1) + n2=np.sum(y==2) + x=np.zeros((n,2)) + + x[y==1,:]=get_2D_samples_gauss(n1,m1,nz) + x[y==2,:]=get_2D_samples_gauss(n2,m2,nz) + + x=x.dot(rot) + + + else: x=0 y=0 print("unknown dataset") - + return x,y.astype(int)
\ No newline at end of file