summaryrefslogtreecommitdiff
path: root/ot/datasets.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-24 17:12:01 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-24 17:12:01 +0200
commit041f9bca2e89d1d111b553159aed862291630b00 (patch)
tree8ce24333e2f916592e7667fa8e0cbc8ebc6dda4a /ot/datasets.py
parent176ff069483b9ba630af8a00ce5edc104168c0a2 (diff)
add domain adaptation
Diffstat (limited to 'ot/datasets.py')
-rw-r--r--ot/datasets.py39
1 files changed, 38 insertions, 1 deletions
diff --git a/ot/datasets.py b/ot/datasets.py
index cebfdac..edc29a9 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -24,4 +24,41 @@ def get_2D_samples_gauss(n,m,sigma):
else:
res= np.random.randn(n,2)*np.sqrt(sigma)+m
return res
- \ No newline at end of file
+
+def get_data_classif(dataset,n,nz=.5,**kwargs):
+ """
+ dataset generation
+ """
+ if dataset.lower()=='3gauss':
+ y=np.floor((np.arange(n)*1.0/n*3))+1
+ x=np.zeros((n,2))
+ # class 1
+ x[y==1,0]=-1.; x[y==1,1]=-1.
+ x[y==2,0]=-1.; x[y==2,1]=1.
+ x[y==3,0]=1. ; x[y==3,1]=0
+
+ x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2)
+ x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
+
+ elif dataset.lower()=='3gauss2':
+ y=np.floor((np.arange(n)*1.0/n*4))+1
+ x=np.zeros((n,2))
+ y[y==4]=3
+ # class 1
+ x[y==1,0]=-1.; x[y==1,1]=-1.
+ x[y==2,0]=-1.; x[y==2,1]=1.
+ x[y==3,0]=1. ; x[y==3,1]=0
+
+ x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2)
+ x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
+# elif dataset.lower()=='sinreg':
+#
+# x=np.random.rand(n,1)
+# y=4*x+np.sin(2*np.pi*x)+nz*np.random.randn(n,1)
+
+ else:
+ x=0
+ y=0
+ print("unknown dataset")
+
+ return x,y \ No newline at end of file