summaryrefslogtreecommitdiff
path: root/ot/datasets.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/datasets.py')
-rw-r--r--ot/datasets.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/ot/datasets.py b/ot/datasets.py
index 588f501..c750812 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -108,9 +108,9 @@ def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs):
x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
elif dataset.lower()=='gaussrot' :
- rot=np.array([[np.cos(theta),-np.sin(theta)],[np.sin(theta),np.cos(theta)]])
- m1=np.array([-1,-1])
- m2=np.array([1,1])
+ rot=np.array([[np.cos(theta),np.sin(theta)],[-np.sin(theta),np.cos(theta)]])
+ m1=np.array([-1,1])
+ m2=np.array([1,-1])
y=np.floor((np.arange(n)*1.0/n*2))+1
n1=np.sum(y==1)
n2=np.sum(y==2)