diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-24 17:12:01 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-24 17:12:01 +0200 |
commit | 041f9bca2e89d1d111b553159aed862291630b00 (patch) | |
tree | 8ce24333e2f916592e7667fa8e0cbc8ebc6dda4a /ot/datasets.py | |
parent | 176ff069483b9ba630af8a00ce5edc104168c0a2 (diff) |
add domain adaptation
Diffstat (limited to 'ot/datasets.py')
-rw-r--r-- | ot/datasets.py | 39 |
1 files changed, 38 insertions, 1 deletions
diff --git a/ot/datasets.py b/ot/datasets.py index cebfdac..edc29a9 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -24,4 +24,41 @@ def get_2D_samples_gauss(n,m,sigma): else: res= np.random.randn(n,2)*np.sqrt(sigma)+m return res -
\ No newline at end of file + +def get_data_classif(dataset,n,nz=.5,**kwargs): + """ + dataset generation + """ + if dataset.lower()=='3gauss': + y=np.floor((np.arange(n)*1.0/n*3))+1 + x=np.zeros((n,2)) + # class 1 + x[y==1,0]=-1.; x[y==1,1]=-1. + x[y==2,0]=-1.; x[y==2,1]=1. + x[y==3,0]=1. ; x[y==3,1]=0 + + x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2) + x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) + + elif dataset.lower()=='3gauss2': + y=np.floor((np.arange(n)*1.0/n*4))+1 + x=np.zeros((n,2)) + y[y==4]=3 + # class 1 + x[y==1,0]=-1.; x[y==1,1]=-1. + x[y==2,0]=-1.; x[y==2,1]=1. + x[y==3,0]=1. ; x[y==3,1]=0 + + x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2) + x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) +# elif dataset.lower()=='sinreg': +# +# x=np.random.rand(n,1) +# y=4*x+np.sin(2*np.pi*x)+nz*np.random.randn(n,1) + + else: + x=0 + y=0 + print("unknown dataset") + + return x,y
\ No newline at end of file |