summaryrefslogtreecommitdiff
path: root/ot/datasets.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-11-03 17:07:22 +0100
committerRémi Flamary <remi.flamary@gmail.com>2016-11-03 17:07:22 +0100
commit7e16b7a80f3a2896351262a02af27a60401b6a5e (patch)
tree47d2ba913bc5ddc726441d284638db9d7c2cff86 /ot/datasets.py
parent86b1c88eb0c2c43853bca38de96d2278cc90ceba (diff)
add mapping estimation with kernels (still debugging)
Diffstat (limited to 'ot/datasets.py')
-rw-r--r--ot/datasets.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/ot/datasets.py b/ot/datasets.py
index 588f501..c750812 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -108,9 +108,9 @@ def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs):
x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
elif dataset.lower()=='gaussrot' :
- rot=np.array([[np.cos(theta),-np.sin(theta)],[np.sin(theta),np.cos(theta)]])
- m1=np.array([-1,-1])
- m2=np.array([1,1])
+ rot=np.array([[np.cos(theta),np.sin(theta)],[-np.sin(theta),np.cos(theta)]])
+ m1=np.array([-1,1])
+ m2=np.array([1,-1])
y=np.floor((np.arange(n)*1.0/n*2))+1
n1=np.sum(y==1)
n2=np.sum(y==2)