From 8d45086670f2666a316a94a33179658d20ac5ec2 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Wed, 26 Oct 2016 15:06:25 +0200 Subject: add domain adaptation demo --- ot/da.py | 3 ++- ot/datasets.py | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) (limited to 'ot') diff --git a/ot/da.py b/ot/da.py index 8ecd952..3ef9fc7 100644 --- a/ot/da.py +++ b/ot/da.py @@ -47,4 +47,5 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1): if idx_begin==-1: W[indices_labels[0],t]=np.min(all_maj) - return transp \ No newline at end of file + return transp + \ No newline at end of file diff --git a/ot/datasets.py b/ot/datasets.py index edc29a9..f22e345 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -37,17 +37,17 @@ def get_data_classif(dataset,n,nz=.5,**kwargs): x[y==2,0]=-1.; x[y==2,1]=1. x[y==3,0]=1. ; x[y==3,1]=0 - x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2) + x[y!=3,:]+=1.5*nz*np.random.randn(sum(y!=3),2) x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) elif dataset.lower()=='3gauss2': - y=np.floor((np.arange(n)*1.0/n*4))+1 + y=np.floor((np.arange(n)*1.0/n*3))+1 x=np.zeros((n,2)) y[y==4]=3 # class 1 - x[y==1,0]=-1.; x[y==1,1]=-1. - x[y==2,0]=-1.; x[y==2,1]=1. - x[y==3,0]=1. ; x[y==3,1]=0 + x[y==1,0]=-2.; x[y==1,1]=-2. + x[y==2,0]=-2.; x[y==2,1]=2. + x[y==3,0]=2. ; x[y==3,1]=0 x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2) x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) -- cgit v1.2.3