summaryrefslogtreecommitdiff
path: root/ot/datasets.py
blob: ba0cfd9585fe345c732c0a80fd426bd42381b767 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
Simple example datasets for OT
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
#
# License: MIT License


import numpy as np
import scipy as sp
from .utils import check_random_state, deprecated


def make_1D_gauss(n, m, s):
    """Build a normalized 1D histogram of a gaussian (n bins, mean m, std s).

    The bins are located at the integer positions ``0, 1, ..., n - 1`` and
    each bin receives the unnormalized gaussian density evaluated at its
    position; the result is then rescaled so that it sums to one.

    Parameters
    ----------
    n : int
        number of bins in the histogram
    m : float
        mean value of the gaussian distribution (in bin units)
    s : float
        standard deviation of the gaussian distribution (in bin units)

    Returns
    -------
    h : ndarray (n,)
        1D histogram of the gaussian distribution, summing to 1
    """
    grid = np.arange(n, dtype=np.float64)
    twice_variance = 2 * s**2
    weights = np.exp(-(grid - m)**2 / twice_variance)
    total_mass = weights.sum()
    return weights / total_mass


@deprecated()
def get_1D_gauss(n, m, sigma):
    """Deprecated alias of make_1D_gauss; call make_1D_gauss instead.

    All arguments are forwarded unchanged, ``sigma`` being passed as the
    standard deviation ``s`` of make_1D_gauss.
    """
    return make_1D_gauss(n=n, m=m, s=sigma)


def make_2D_samples_gauss(n, m, sigma, random_state=None):
    """Return n samples drawn from 2D gaussian N(m,sigma)

    Parameters
    ----------
    n : int
        number of samples to make
    m : ndarray, shape (2,)
        mean value of the gaussian distribution
    sigma : float, ndarray of shape (2,) or ndarray of shape (2, 2)
        covariance of the gaussian distribution:

        - a scalar (or any array with a single element) is an isotropic
          variance shared by both axes (samples are scaled by its square
          root);
        - an array of shape (2,) is taken as the diagonal of the covariance
          matrix, i.e. one variance per axis;
        - an array of shape (2, 2) is the full covariance matrix.
    random_state : int, RandomState instance or None, optional (default=None)
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    Returns
    -------
    X : ndarray, shape (n, 2)
        n samples drawn from N(m, sigma).
    """
    # Explicit submodule import: a bare ``import scipy`` does not load
    # ``scipy.linalg`` on older SciPy releases, so ``sp.linalg`` may not exist.
    from scipy.linalg import sqrtm

    generator = check_random_state(random_state)
    # np.atleast_1d maps Python/NumPy scalars *and* 0-d arrays to shape (1,);
    # np.isscalar is False for 0-d arrays, on which len() would raise.
    sigma = np.atleast_1d(sigma)
    if len(sigma) > 1:
        if sigma.ndim == 1:
            # per-axis variances -> diagonal covariance matrix
            sigma = np.diag(sigma)
        # For a symmetric PSD sigma the principal square root P is symmetric,
        # so rows of Z.dot(P) with Z ~ N(0, I) have covariance P^T P = sigma.
        P = sqrtm(sigma)
        res = generator.randn(n, 2).dot(P) + m
    else:
        # isotropic case: sigma is a variance, hence the square root
        res = generator.randn(n, 2) * np.sqrt(sigma) + m
    return res


@deprecated()
def get_2D_samples_gauss(n, m, sigma, random_state=None):
    """Deprecated alias of make_2D_samples_gauss; use that function instead.

    All arguments, including ``random_state``, are forwarded unchanged so
    that seeding the deprecated alias yields the same samples as seeding
    make_2D_samples_gauss directly.
    """
    # Forward the caller's random_state (it was previously replaced by None,
    # which silently ignored any seed or generator passed in).
    return make_2D_samples_gauss(n, m, sigma, random_state=random_state)


def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
    """Dataset generation for classification problems

    Parameters
    ----------
    dataset : str
        type of classification problem (case-insensitive), one of:

        - '3gauss' : three 2D classes centered at (-1, -1), (-1, 1) and
          (1, 0); gaussian noise of std ``1.5 * nz`` for classes 1 and 2 and
          ``2 * nz`` for class 3.
        - '3gauss2' : three 2D classes centered at (-2, -2), (-2, 2) and
          (2, 0); gaussian noise of std ``nz`` for classes 1 and 2 and
          ``2 * nz`` for class 3.
        - 'gaussrot' : two 2D gaussian classes centered at (-1, 1) and
          (1, -1) with isotropic variance ``nz`` (std ``sqrt(nz)``), the
          whole point cloud then being rotated counter-clockwise by
          ``theta``.
    n : int
        number of training samples
    nz : float
        noise level (>0)
    theta : float
        rotation angle in radians, only used by 'gaussrot'
    random_state : int, RandomState instance or None, optional (default=None)
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.
    **kwargs
        unused

    Returns
    -------
    X : ndarray, shape (n, d)
        n observation of size d (d = 2 for every dataset above)
    y : ndarray, shape (n,)
        integer labels of the samples (1..3, or 1..2 for 'gaussrot'); samples
        are ordered by class, each class taking a contiguous, near-equal
        share of the n samples.

    For an unknown ``dataset`` the message "unknown dataset" is printed and
    0-d arrays (value 0) are returned for both X and y.
    """
    generator = check_random_state(random_state)
    name = dataset.lower()

    if name == '3gauss':
        y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
        x = np.zeros((n, 2))
        # class centers
        x[y == 1] = [-1., -1.]
        x[y == 2] = [-1., 1.]
        x[y == 3] = [1., 0.]

        # noise: classes 1-2 first, then class 3 (draw order matters for
        # reproducibility with a seeded generator)
        is3 = y == 3
        not3 = ~is3
        x[not3, :] += 1.5 * nz * generator.randn(np.count_nonzero(not3), 2)
        x[is3, :] += 2 * nz * generator.randn(np.count_nonzero(is3), 2)

    elif name == '3gauss2':
        y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
        x = np.zeros((n, 2))
        # defensive clamp: labels computed above are already in {1, 2, 3}
        y[y == 4] = 3
        # class centers
        x[y == 1] = [-2., -2.]
        x[y == 2] = [-2., 2.]
        x[y == 3] = [2., 0.]

        is3 = y == 3
        not3 = ~is3
        x[not3, :] += nz * generator.randn(np.count_nonzero(not3), 2)
        x[is3, :] += 2 * nz * generator.randn(np.count_nonzero(is3), 2)

    elif name == 'gaussrot':
        # right-multiplying row vectors by this matrix rotates them
        # counter-clockwise by theta
        rot = np.array(
            [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]])
        m1 = np.array([-1, 1])
        m2 = np.array([1, -1])
        y = np.floor((np.arange(n) * 1.0 / n * 2)) + 1
        n1 = np.count_nonzero(y == 1)
        n2 = np.count_nonzero(y == 2)
        x = np.zeros((n, 2))

        # Call the sampler directly (not through the @deprecated
        # get_2D_samples_gauss wrapper) and hand it the seeded generator so
        # the output is reproducible for a given random_state.
        x[y == 1, :] = make_2D_samples_gauss(n1, m1, nz,
                                             random_state=generator)
        x[y == 2, :] = make_2D_samples_gauss(n2, m2, nz,
                                             random_state=generator)

        x = x.dot(rot)

    else:
        x = np.array(0)
        y = np.array(0)
        print("unknown dataset")

    return x, y.astype(int)


@deprecated()
def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
    """Deprecated alias of make_data_classif; use that function instead.

    Every argument (``nz``, ``theta``, ``random_state`` and extra keyword
    arguments included) is forwarded unchanged to make_data_classif.
    """
    # Forward the caller's values; they were previously overwritten by the
    # defaults (nz=.5, theta=0, random_state=None).
    return make_data_classif(dataset, n, nz=nz, theta=theta,
                             random_state=random_state, **kwargs)