summaryrefslogtreecommitdiff
path: root/ot/datasets.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/datasets.py')
-rw-r--r--ot/datasets.py164
1 files changed, 164 insertions, 0 deletions
diff --git a/ot/datasets.py b/ot/datasets.py
new file mode 100644
index 0000000..ba0cfd9
--- /dev/null
+++ b/ot/datasets.py
@@ -0,0 +1,164 @@
+"""
+Simple example datasets for OT
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+
+import numpy as np
+import scipy as sp
+from .utils import check_random_state, deprecated
+
+
+def make_1D_gauss(n, m, s):
+ """return a 1D histogram for a gaussian distribution (n bins, mean m and std s)
+
+ Parameters
+ ----------
+ n : int
+ number of bins in the histogram
+ m : float
+ mean value of the gaussian distribution
+ s : float
+ standard deviaton of the gaussian distribution
+
+ Returns
+ -------
+ h : ndarray (n,)
+ 1D histogram for a gaussian distribution
+ """
+ x = np.arange(n, dtype=np.float64)
+ h = np.exp(-(x - m)**2 / (2 * s**2))
+ return h / h.sum()
+
+
+@deprecated()
+def get_1D_gauss(n, m, sigma):
+ """ Deprecated see make_1D_gauss """
+ return make_1D_gauss(n, m, sigma)
+
+
+def make_2D_samples_gauss(n, m, sigma, random_state=None):
+ """Return n samples drawn from 2D gaussian N(m,sigma)
+
+ Parameters
+ ----------
+ n : int
+ number of samples to make
+ m : ndarray, shape (2,)
+ mean value of the gaussian distribution
+ sigma : ndarray, shape (2, 2)
+ covariance matrix of the gaussian distribution
+ random_state : int, RandomState instance or None, optional (default=None)
+ If int, random_state is the seed used by the random number generator;
+ If RandomState instance, random_state is the random number generator;
+ If None, the random number generator is the RandomState instance used
+ by `np.random`.
+
+ Returns
+ -------
+ X : ndarray, shape (n, 2)
+ n samples drawn from N(m, sigma).
+ """
+
+ generator = check_random_state(random_state)
+ if np.isscalar(sigma):
+ sigma = np.array([sigma, ])
+ if len(sigma) > 1:
+ P = sp.linalg.sqrtm(sigma)
+ res = generator.randn(n, 2).dot(P) + m
+ else:
+ res = generator.randn(n, 2) * np.sqrt(sigma) + m
+ return res
+
+
+@deprecated()
+def get_2D_samples_gauss(n, m, sigma, random_state=None):
+ """ Deprecated see make_2D_samples_gauss """
+ return make_2D_samples_gauss(n, m, sigma, random_state=None)
+
+
+def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
+ """Dataset generation for classification problems
+
+ Parameters
+ ----------
+ dataset : str
+ type of classification problem (see code)
+ n : int
+ number of training samples
+ nz : float
+ noise level (>0)
+ random_state : int, RandomState instance or None, optional (default=None)
+ If int, random_state is the seed used by the random number generator;
+ If RandomState instance, random_state is the random number generator;
+ If None, the random number generator is the RandomState instance used
+ by `np.random`.
+
+ Returns
+ -------
+ X : ndarray, shape (n, d)
+ n observation of size d
+ y : ndarray, shape (n,)
+ labels of the samples.
+ """
+ generator = check_random_state(random_state)
+
+ if dataset.lower() == '3gauss':
+ y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+ x = np.zeros((n, 2))
+ # class 1
+ x[y == 1, 0] = -1.
+ x[y == 1, 1] = -1.
+ x[y == 2, 0] = -1.
+ x[y == 2, 1] = 1.
+ x[y == 3, 0] = 1.
+ x[y == 3, 1] = 0
+
+ x[y != 3, :] += 1.5 * nz * generator.randn(sum(y != 3), 2)
+ x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2)
+
+ elif dataset.lower() == '3gauss2':
+ y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+ x = np.zeros((n, 2))
+ y[y == 4] = 3
+ # class 1
+ x[y == 1, 0] = -2.
+ x[y == 1, 1] = -2.
+ x[y == 2, 0] = -2.
+ x[y == 2, 1] = 2.
+ x[y == 3, 0] = 2.
+ x[y == 3, 1] = 0
+
+ x[y != 3, :] += nz * generator.randn(sum(y != 3), 2)
+ x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2)
+
+ elif dataset.lower() == 'gaussrot':
+ rot = np.array(
+ [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]])
+ m1 = np.array([-1, 1])
+ m2 = np.array([1, -1])
+ y = np.floor((np.arange(n) * 1.0 / n * 2)) + 1
+ n1 = np.sum(y == 1)
+ n2 = np.sum(y == 2)
+ x = np.zeros((n, 2))
+
+ x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz, random_state=generator)
+ x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz, random_state=generator)
+
+ x = x.dot(rot)
+
+ else:
+ x = np.array(0)
+ y = np.array(0)
+ print("unknown dataset")
+
+ return x, y.astype(int)
+
+
+@deprecated()
+def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
+ """ Deprecated see make_data_classif """
+ return make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs)