diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2018-06-11 11:24:57 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-11 11:24:57 +0200 |
commit | 47730fc727c0f54e8459964d9208ad824e3f91da (patch) | |
tree | 48952166e88e602f1843bd15c0187c7d5ffb6cac /ot | |
parent | 4641ec5f2ddbff1a468afaf65741aecae44738cc (diff) | |
parent | 530dc93a60e9b81fb8d1b44680deea77dacf660b (diff) |
Merge branch 'master' into remove_otda_v05
Diffstat (limited to 'ot')
-rw-r--r-- | ot/datasets.py | 60 | ||||
-rw-r--r-- | ot/gromov.py | 1 | ||||
-rw-r--r-- | ot/utils.py | 20 |
3 files changed, 66 insertions, 15 deletions
diff --git a/ot/datasets.py b/ot/datasets.py index e4fe118..362a89b 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -9,9 +9,10 @@ Simple example datasets for OT import numpy as np import scipy as sp +from .utils import check_random_state, deprecated -def get_1D_gauss(n, m, s): +def make_1D_gauss(n, m, s): """return a 1D histogram for a gaussian distribution (n bins, mean m and std s) Parameters @@ -36,19 +37,29 @@ def get_1D_gauss(n, m, s): return h / h.sum() -def get_2D_samples_gauss(n, m, sigma): +@deprecated() +def get_1D_gauss(n, m, sigma, random_state=None): + """ Deprecated see make_1D_gauss """ + return make_1D_gauss(n, m, sigma, random_state=None) + + +def make_2D_samples_gauss(n, m, sigma, random_state=None): """return n samples drawn from 2D gaussian N(m,sigma) Parameters ---------- n : int - number of bins in the histogram + number of samples to make m : np.array (2,) mean value of the gaussian distribution sigma : np.array (2,2) covariance matrix of the gaussian distribution - + random_state : int, RandomState instance or None, optional (default=None) + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. Returns ------- @@ -56,17 +67,25 @@ def get_2D_samples_gauss(n, m, sigma): n samples drawn from N(m,sigma) """ + + generator = check_random_state(random_state) if np.isscalar(sigma): sigma = np.array([sigma, ]) if len(sigma) > 1: P = sp.linalg.sqrtm(sigma) - res = np.random.randn(n, 2).dot(P) + m + res = generator.randn(n, 2).dot(P) + m else: - res = np.random.randn(n, 2) * np.sqrt(sigma) + m + res = generator.randn(n, 2) * np.sqrt(sigma) + m return res -def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): +@deprecated() +def get_2D_samples_gauss(n, m, sigma, random_state=None): + """ Deprecated see make_2D_samples_gauss """ + return make_2D_samples_gauss(n, m, sigma, random_state=None) + + +def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): """ dataset generation for classification problems Parameters @@ -78,7 +97,11 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): number of training samples nz : float noise level (>0) - + random_state : int, RandomState instance or None, optional (default=None) + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. Returns ------- @@ -88,6 +111,9 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): labels of the samples """ + + generator = check_random_state(random_state) + if dataset.lower() == '3gauss': y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 x = np.zeros((n, 2)) @@ -99,8 +125,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): x[y == 3, 0] = 1. x[y == 3, 1] = 0 - x[y != 3, :] += 1.5 * nz * np.random.randn(sum(y != 3), 2) - x[y == 3, :] += 2 * nz * np.random.randn(sum(y == 3), 2) + x[y != 3, :] += 1.5 * nz * generator.randn(sum(y != 3), 2) + x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2) elif dataset.lower() == '3gauss2': y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 @@ -114,8 +140,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): x[y == 3, 0] = 2. x[y == 3, 1] = 0 - x[y != 3, :] += nz * np.random.randn(sum(y != 3), 2) - x[y == 3, :] += 2 * nz * np.random.randn(sum(y == 3), 2) + x[y != 3, :] += nz * generator.randn(sum(y != 3), 2) + x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2) elif dataset.lower() == 'gaussrot': rot = np.array( @@ -127,8 +153,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): n2 = np.sum(y == 2) x = np.zeros((n, 2)) - x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz) - x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz) + x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz, random_state=generator) + x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz, random_state=generator) x = x.dot(rot) @@ -138,3 +164,9 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs): print("unknown dataset") return x, y.astype(int) + + +@deprecated() +def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): + """ Deprecated see make_data_classif """ + return make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs) diff --git a/ot/gromov.py b/ot/gromov.py index 65b2e29..0278e99 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*-
"""
Gromov-Wasserstein transport method
-===================================
"""
diff --git a/ot/utils.py b/ot/utils.py index 17983f2..7dac283 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -225,6 +225,26 @@ def check_params(**kwargs): return check +def check_random_state(seed): + """Turn seed into a np.random.RandomState instance + Parameters + ---------- + seed : None | int | instance of RandomState + If seed is None, return the RandomState singleton used by np.random. + If seed is an int, return a new RandomState instance seeded with seed. + If seed is already a RandomState instance, return it. + Otherwise raise ValueError. + """ + if seed is None or seed is np.random: + return np.random.mtrand._rand + if isinstance(seed, (int, np.integer)): + return np.random.RandomState(seed) + if isinstance(seed, np.random.RandomState): + return seed + raise ValueError('{} cannot be used to seed a numpy.random.RandomState' + ' instance'.format(seed)) + + class deprecated(object): """Decorator to mark a function or class as deprecated. |