summaryrefslogtreecommitdiff
path: root/ot/datasets.py
blob: f22e345315a9c32e5a503a7ce8fb3a485792869d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
Simple example datasets for OT
"""


import numpy as np
import scipy as sp


def get_1D_gauss(n,m,s):
    """Return a 1D histogram of a Gaussian distribution.

    Parameters
    ----------
    n : int
        Number of bins; the bin positions are 0, 1, ..., n-1.
    m : float
        Mean of the Gaussian, expressed in bin coordinates.
    s : float
        Standard deviation of the Gaussian.

    Returns
    -------
    h : ndarray of shape (n,)
        Gaussian density sampled on the bins and normalized to sum to 1.
    """
    x = np.arange(n, dtype=np.float64)
    # `**` is exponentiation; `^` is bitwise XOR in Python, which fails on
    # floats and yields a wrong variance term on integers.
    h = np.exp(-(x - m) ** 2 / (2 * s ** 2))
    return h / h.sum()
    
    
def get_2D_samples_gauss(n,m,sigma):
    """Return samples drawn from a 2D Gaussian distribution.

    Parameters
    ----------
    n : int
        Number of samples to draw.
    m : scalar or array-like
        Mean added to every sample (broadcast against an (n, 2) array).
    sigma : scalar or array-like
        Covariance. A scalar (or length-1 sequence) is treated as an
        isotropic variance; anything longer is passed to the matrix square
        root and used as a full covariance matrix.

    Returns
    -------
    res : ndarray of shape (n, 2)
        The generated samples, drawn with NumPy's global random state.
    """
    if np.isscalar(sigma):
        sigma = np.array([sigma, ])
    if len(sigma) > 1:
        # Import the submodule explicitly: a bare `import scipy` does not
        # guarantee that `scipy.linalg` is loaded on every SciPy version.
        from scipy.linalg import sqrtm
        P = sqrtm(sigma)
        res = np.random.randn(n, 2).dot(P) + m
    else:
        res = np.random.randn(n, 2) * np.sqrt(sigma) + m
    return res

def get_data_classif(dataset,n,nz=.5,**kwargs):
    """
    Generate a toy 2D classification dataset made of three Gaussian classes.

    Parameters
    ----------
    dataset : str
        Dataset name, matched case-insensitively: '3gauss' or '3gauss2'.
    n : int
        Number of samples; labels are assigned in three consecutive runs.
    nz : float
        Noise level scaling the Gaussian perturbation around each class
        centre (class 3 always uses ``2*nz``).
    **kwargs
        Accepted but not used.

    Returns
    -------
    x : ndarray of shape (n, 2)
        Sample coordinates.
    y : ndarray of shape (n,)
        Float class labels taking the values 1, 2 and 3.

    For any other dataset name, "unknown dataset" is printed and ``(0, 0)``
    is returned.
    """
    key = dataset.lower()
    if key == '3gauss':
        centers = ((-1., -1.), (-1., 1.), (1., 0))
        spread_12 = 1.5 * nz
    elif key == '3gauss2':
        centers = ((-2., -2.), (-2., 2.), (2., 0))
        spread_12 = nz
    else:
        print("unknown dataset")
        return 0, 0

    # Labels 1, 2, 3 in three consecutive runs of (roughly) equal length.
    labels = np.floor(np.arange(n) * 1.0 / n * 3) + 1
    if key == '3gauss2':
        # Fold a stray label 4 back into class 3.
        labels[labels == 4] = 3

    # Place every sample on its class centre.
    points = np.zeros((n, 2))
    for cls, center in enumerate(centers, start=1):
        points[labels == cls, :] = center

    # Perturb classes 1 and 2 first, then class 3, so the draws from the
    # global random state happen in a fixed order.
    first_two = labels != 3
    third = labels == 3
    points[first_two, :] += spread_12 * np.random.randn(np.count_nonzero(first_two), 2)
    points[third, :] += 2 * nz * np.random.randn(np.count_nonzero(third), 2)

    return points, labels