diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-26 15:06:25 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-26 15:06:25 +0200 |
commit | 8d45086670f2666a316a94a33179658d20ac5ec2 (patch) | |
tree | dc4faa77fc8da9997e7424584379817149ea1013 /ot/datasets.py | |
parent | ab27b8387201e57173255b2d22ad3c8d5057ad8b (diff) |
add domain adaptation demo
Diffstat (limited to 'ot/datasets.py')
-rw-r--r-- | ot/datasets.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/ot/datasets.py b/ot/datasets.py index edc29a9..f22e345 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -37,17 +37,17 @@ def get_data_classif(dataset,n,nz=.5,**kwargs): x[y==2,0]=-1.; x[y==2,1]=1. x[y==3,0]=1. ; x[y==3,1]=0 - x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2) + x[y!=3,:]+=1.5*nz*np.random.randn(sum(y!=3),2) x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) elif dataset.lower()=='3gauss2': - y=np.floor((np.arange(n)*1.0/n*4))+1 + y=np.floor((np.arange(n)*1.0/n*3))+1 x=np.zeros((n,2)) y[y==4]=3 # class 1 - x[y==1,0]=-1.; x[y==1,1]=-1. - x[y==2,0]=-1.; x[y==2,1]=1. - x[y==3,0]=1. ; x[y==3,1]=0 + x[y==1,0]=-2.; x[y==1,1]=-2. + x[y==2,0]=-2.; x[y==2,1]=2. + x[y==3,0]=2. ; x[y==3,1]=0 x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2) x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) |