diff options
Diffstat (limited to 'ot/datasets.py')
-rw-r--r-- | ot/datasets.py | 111 |
1 files changed, 60 insertions, 51 deletions
diff --git a/ot/datasets.py b/ot/datasets.py index 7816833..e4fe118 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -2,12 +2,16 @@ Simple example datasets for OT """ +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + import numpy as np import scipy as sp -def get_1D_gauss(n,m,s): +def get_1D_gauss(n, m, s): """return a 1D histogram for a gaussian distribution (n bins, mean m and std s) Parameters @@ -27,12 +31,12 @@ def get_1D_gauss(n,m,s): 1D histogram for a gaussian distribution """ - x=np.arange(n,dtype=np.float64) - h=np.exp(-(x-m)**2/(2*s**2)) - return h/h.sum() + x = np.arange(n, dtype=np.float64) + h = np.exp(-(x - m)**2 / (2 * s**2)) + return h / h.sum() -def get_2D_samples_gauss(n,m,sigma): +def get_2D_samples_gauss(n, m, sigma): """return n samples drawn from 2D gaussian N(m,sigma) Parameters @@ -52,17 +56,17 @@ def get_2D_samples_gauss(n,m,sigma): n samples drawn from N(m,sigma) """ - if np.isscalar(sigma): - sigma=np.array([sigma,]) - if len(sigma)>1: - P=sp.linalg.sqrtm(sigma) - res= np.random.randn(n,2).dot(P)+m + if np.isscalar(sigma): + sigma = np.array([sigma, ]) + if len(sigma) > 1: + P = sp.linalg.sqrtm(sigma) + res = np.random.randn(n, 2).dot(P) + m else: - res= np.random.randn(n,2)*np.sqrt(sigma)+m + res = np.random.randn(n, 2) * np.sqrt(sigma) + m return res -def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs): +def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): """ dataset generation for classification problems Parameters @@ -84,48 +88,53 @@ def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs): labels of the samples """ - if dataset.lower()=='3gauss': - y=np.floor((np.arange(n)*1.0/n*3))+1 - x=np.zeros((n,2)) + if dataset.lower() == '3gauss': + y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 + x = np.zeros((n, 2)) # class 1 - x[y==1,0]=-1.; x[y==1,1]=-1. - x[y==2,0]=-1.; x[y==2,1]=1. - x[y==3,0]=1. ; x[y==3,1]=0 - - x[y!=3,:]+=1.5*nz*np.random.randn(sum(y!=3),2) - x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) - - elif dataset.lower()=='3gauss2': - y=np.floor((np.arange(n)*1.0/n*3))+1 - x=np.zeros((n,2)) - y[y==4]=3 + x[y == 1, 0] = -1. + x[y == 1, 1] = -1. + x[y == 2, 0] = -1. + x[y == 2, 1] = 1. + x[y == 3, 0] = 1. + x[y == 3, 1] = 0 + + x[y != 3, :] += 1.5 * nz * np.random.randn(sum(y != 3), 2) + x[y == 3, :] += 2 * nz * np.random.randn(sum(y == 3), 2) + + elif dataset.lower() == '3gauss2': + y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 + x = np.zeros((n, 2)) + y[y == 4] = 3 # class 1 - x[y==1,0]=-2.; x[y==1,1]=-2. - x[y==2,0]=-2.; x[y==2,1]=2. - x[y==3,0]=2. ; x[y==3,1]=0 - - x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2) - x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) - - elif dataset.lower()=='gaussrot' : - rot=np.array([[np.cos(theta),np.sin(theta)],[-np.sin(theta),np.cos(theta)]]) - m1=np.array([-1,1]) - m2=np.array([1,-1]) - y=np.floor((np.arange(n)*1.0/n*2))+1 - n1=np.sum(y==1) - n2=np.sum(y==2) - x=np.zeros((n,2)) - - x[y==1,:]=get_2D_samples_gauss(n1,m1,nz) - x[y==2,:]=get_2D_samples_gauss(n2,m2,nz) - - x=x.dot(rot) - - + x[y == 1, 0] = -2. + x[y == 1, 1] = -2. + x[y == 2, 0] = -2. + x[y == 2, 1] = 2. + x[y == 3, 0] = 2. + x[y == 3, 1] = 0 + + x[y != 3, :] += nz * np.random.randn(sum(y != 3), 2) + x[y == 3, :] += 2 * nz * np.random.randn(sum(y == 3), 2) + + elif dataset.lower() == 'gaussrot': + rot = np.array( + [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]]) + m1 = np.array([-1, 1]) + m2 = np.array([1, -1]) + y = np.floor((np.arange(n) * 1.0 / n * 2)) + 1 + n1 = np.sum(y == 1) + n2 = np.sum(y == 2) + x = np.zeros((n, 2)) + + x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz) + x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz) + + x = x.dot(rot) else: - x=np.array(0) - y=np.array(0) + x = np.array(0) + y = np.array(0) print("unknown dataset") - return x,y.astype(int)
\ No newline at end of file + return x, y.astype(int) |