summary | refs | log | tree | commit | diff
path: root/ot/datasets.py
diff options
context:
space:
mode:
author: Alexandre Gramfort <alexandre.gramfort@m4x.org> 2017-07-12 23:52:57 +0200
committer: Alexandre Gramfort <alexandre.gramfort@m4x.org> 2017-07-20 14:08:02 +0200
commit: 37ca3142a4c8c808382f5eb1c23bf198c3e4610e (patch)
tree: 55c2b56581b8f0f607d8e0a50778267baa4960e3 /ot/datasets.py
parent: da21b9888e77f7512727a4f50c60bd475e2c9606 (diff)
pep8
Diffstat (limited to 'ot/datasets.py')
-rw-r--r-- ot/datasets.py 107
1 files changed, 56 insertions, 51 deletions
diff --git a/ot/datasets.py b/ot/datasets.py
index 7816833..4371a23 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -7,7 +7,7 @@ import numpy as np
import scipy as sp
-def get_1D_gauss(n,m,s):
+def get_1D_gauss(n, m, s):
"""return a 1D histogram for a gaussian distribution (n bins, mean m and std s)
Parameters
@@ -27,12 +27,12 @@ def get_1D_gauss(n,m,s):
1D histogram for a gaussian distribution
"""
- x=np.arange(n,dtype=np.float64)
- h=np.exp(-(x-m)**2/(2*s**2))
- return h/h.sum()
+ x = np.arange(n, dtype=np.float64)
+ h = np.exp(-(x - m)**2 / (2 * s**2))
+ return h / h.sum()
-def get_2D_samples_gauss(n,m,sigma):
+def get_2D_samples_gauss(n, m, sigma):
"""return n samples drawn from 2D gaussian N(m,sigma)
Parameters
@@ -52,17 +52,17 @@ def get_2D_samples_gauss(n,m,sigma):
n samples drawn from N(m,sigma)
"""
- if np.isscalar(sigma):
- sigma=np.array([sigma,])
- if len(sigma)>1:
- P=sp.linalg.sqrtm(sigma)
- res= np.random.randn(n,2).dot(P)+m
+ if np.isscalar(sigma):
+ sigma = np.array([sigma, ])
+ if len(sigma) > 1:
+ P = sp.linalg.sqrtm(sigma)
+ res = np.random.randn(n, 2).dot(P) + m
else:
- res= np.random.randn(n,2)*np.sqrt(sigma)+m
+ res = np.random.randn(n, 2) * np.sqrt(sigma) + m
return res
-def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs):
+def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
""" dataset generation for classification problems
Parameters
@@ -84,48 +84,53 @@ def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs):
labels of the samples
"""
- if dataset.lower()=='3gauss':
- y=np.floor((np.arange(n)*1.0/n*3))+1
- x=np.zeros((n,2))
+ if dataset.lower() == '3gauss':
+ y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+ x = np.zeros((n, 2))
# class 1
- x[y==1,0]=-1.; x[y==1,1]=-1.
- x[y==2,0]=-1.; x[y==2,1]=1.
- x[y==3,0]=1. ; x[y==3,1]=0
-
- x[y!=3,:]+=1.5*nz*np.random.randn(sum(y!=3),2)
- x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
-
- elif dataset.lower()=='3gauss2':
- y=np.floor((np.arange(n)*1.0/n*3))+1
- x=np.zeros((n,2))
- y[y==4]=3
+ x[y == 1, 0] = -1.
+ x[y == 1, 1] = -1.
+ x[y == 2, 0] = -1.
+ x[y == 2, 1] = 1.
+ x[y == 3, 0] = 1.
+ x[y == 3, 1] = 0
+
+ x[y != 3, :] += 1.5 * nz * np.random.randn(sum(y != 3), 2)
+ x[y == 3, :] += 2 * nz * np.random.randn(sum(y == 3), 2)
+
+ elif dataset.lower() == '3gauss2':
+ y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+ x = np.zeros((n, 2))
+ y[y == 4] = 3
# class 1
- x[y==1,0]=-2.; x[y==1,1]=-2.
- x[y==2,0]=-2.; x[y==2,1]=2.
- x[y==3,0]=2. ; x[y==3,1]=0
-
- x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2)
- x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
-
- elif dataset.lower()=='gaussrot' :
- rot=np.array([[np.cos(theta),np.sin(theta)],[-np.sin(theta),np.cos(theta)]])
- m1=np.array([-1,1])
- m2=np.array([1,-1])
- y=np.floor((np.arange(n)*1.0/n*2))+1
- n1=np.sum(y==1)
- n2=np.sum(y==2)
- x=np.zeros((n,2))
-
- x[y==1,:]=get_2D_samples_gauss(n1,m1,nz)
- x[y==2,:]=get_2D_samples_gauss(n2,m2,nz)
-
- x=x.dot(rot)
-
-
+ x[y == 1, 0] = -2.
+ x[y == 1, 1] = -2.
+ x[y == 2, 0] = -2.
+ x[y == 2, 1] = 2.
+ x[y == 3, 0] = 2.
+ x[y == 3, 1] = 0
+
+ x[y != 3, :] += nz * np.random.randn(sum(y != 3), 2)
+ x[y == 3, :] += 2 * nz * np.random.randn(sum(y == 3), 2)
+
+ elif dataset.lower() == 'gaussrot':
+ rot = np.array(
+ [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]])
+ m1 = np.array([-1, 1])
+ m2 = np.array([1, -1])
+ y = np.floor((np.arange(n) * 1.0 / n * 2)) + 1
+ n1 = np.sum(y == 1)
+ n2 = np.sum(y == 2)
+ x = np.zeros((n, 2))
+
+ x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz)
+ x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz)
+
+ x = x.dot(rot)
else:
- x=np.array(0)
- y=np.array(0)
+ x = np.array(0)
+ y = np.array(0)
print("unknown dataset")
- return x,y.astype(int)
\ No newline at end of file
+ return x, y.astype(int)