summaryrefslogtreecommitdiff
path: root/ot/datasets.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-26 15:06:25 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-26 15:06:25 +0200
commit8d45086670f2666a316a94a33179658d20ac5ec2 (patch)
treedc4faa77fc8da9997e7424584379817149ea1013 /ot/datasets.py
parentab27b8387201e57173255b2d22ad3c8d5057ad8b (diff)
add domain adaptation demo
Diffstat (limited to 'ot/datasets.py')
-rw-r--r--ot/datasets.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/ot/datasets.py b/ot/datasets.py
index edc29a9..f22e345 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -37,17 +37,17 @@ def get_data_classif(dataset,n,nz=.5,**kwargs):
x[y==2,0]=-1.; x[y==2,1]=1.
x[y==3,0]=1. ; x[y==3,1]=0
- x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2)
+ x[y!=3,:]+=1.5*nz*np.random.randn(sum(y!=3),2)
x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
elif dataset.lower()=='3gauss2':
- y=np.floor((np.arange(n)*1.0/n*4))+1
+ y=np.floor((np.arange(n)*1.0/n*3))+1
x=np.zeros((n,2))
y[y==4]=3
# class 1
- x[y==1,0]=-1.; x[y==1,1]=-1.
- x[y==2,0]=-1.; x[y==2,1]=1.
- x[y==3,0]=1. ; x[y==3,1]=0
+ x[y==1,0]=-2.; x[y==1,1]=-2.
+ x[y==2,0]=-2.; x[y==2,1]=2.
+ x[y==3,0]=2. ; x[y==3,1]=0
x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2)
x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)