diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-03 17:07:22 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-03 17:07:22 +0100 |
commit | 7e16b7a80f3a2896351262a02af27a60401b6a5e (patch) | |
tree | 47d2ba913bc5ddc726441d284638db9d7c2cff86 /ot/datasets.py | |
parent | 86b1c88eb0c2c43853bca38de96d2278cc90ceba (diff) |
add mapping estimation with kernels (still debugging)
Diffstat (limited to 'ot/datasets.py')
-rw-r--r-- | ot/datasets.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/ot/datasets.py b/ot/datasets.py index 588f501..c750812 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -108,9 +108,9 @@ def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs): x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) elif dataset.lower()=='gaussrot' : - rot=np.array([[np.cos(theta),-np.sin(theta)],[np.sin(theta),np.cos(theta)]]) - m1=np.array([-1,-1]) - m2=np.array([1,1]) + rot=np.array([[np.cos(theta),np.sin(theta)],[-np.sin(theta),np.cos(theta)]]) + m1=np.array([-1,1]) + m2=np.array([1,-1]) y=np.floor((np.arange(n)*1.0/n*2))+1 n1=np.sum(y==1) n2=np.sum(y==2) |