diff options
Diffstat (limited to 'ot')
-rw-r--r-- | ot/da.py | 3 | ||||
-rw-r--r-- | ot/datasets.py | 10 |
2 files changed, 7 insertions, 6 deletions
@@ -47,4 +47,5 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1): if idx_begin==-1: W[indices_labels[0],t]=np.min(all_maj) - return transp
\ No newline at end of file + return transp +
\ No newline at end of file diff --git a/ot/datasets.py b/ot/datasets.py index edc29a9..f22e345 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -37,17 +37,17 @@ def get_data_classif(dataset,n,nz=.5,**kwargs): x[y==2,0]=-1.; x[y==2,1]=1. x[y==3,0]=1. ; x[y==3,1]=0 - x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2) + x[y!=3,:]+=1.5*nz*np.random.randn(sum(y!=3),2) x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) elif dataset.lower()=='3gauss2': - y=np.floor((np.arange(n)*1.0/n*4))+1 + y=np.floor((np.arange(n)*1.0/n*3))+1 x=np.zeros((n,2)) y[y==4]=3 # class 1 - x[y==1,0]=-1.; x[y==1,1]=-1. - x[y==2,0]=-1.; x[y==2,1]=1. - x[y==3,0]=1. ; x[y==3,1]=0 + x[y==1,0]=-2.; x[y==1,1]=-2. + x[y==2,0]=-2.; x[y==2,1]=2. + x[y==3,0]=2. ; x[y==3,1]=0 x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2) x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) |