summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorNicolas Courty <ncourty@irisa.fr>2018-06-10 22:59:27 +0200
committerGitHub <noreply@github.com>2018-06-10 22:59:27 +0200
commit530dc93a60e9b81fb8d1b44680deea77dacf660b (patch)
treea2ae2e21bf7fac9fe2fa368782d7df5ea9974342 /ot
parent90efa5a8b189214d1aeb81920b2bb04ce0c261ca (diff)
parent2b375f263ef88100b0321c8ef1b3605dfbb95b3d (diff)
Merge pull request #49 from rflamary/dataset_fun
Dataset functions + test/notebooks update
Diffstat (limited to 'ot')
-rw-r--r--ot/datasets.py60
-rw-r--r--ot/gromov.py1
-rw-r--r--ot/utils.py20
3 files changed, 66 insertions, 15 deletions
diff --git a/ot/datasets.py b/ot/datasets.py
index e4fe118..362a89b 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -9,9 +9,10 @@ Simple example datasets for OT
import numpy as np
import scipy as sp
+from .utils import check_random_state, deprecated
-def get_1D_gauss(n, m, s):
+def make_1D_gauss(n, m, s):
"""return a 1D histogram for a gaussian distribution (n bins, mean m and std s)
Parameters
@@ -36,19 +37,29 @@ def get_1D_gauss(n, m, s):
return h / h.sum()
-def get_2D_samples_gauss(n, m, sigma):
+@deprecated()
+def get_1D_gauss(n, m, sigma, random_state=None):
+ """ Deprecated see make_1D_gauss """
+ return make_1D_gauss(n, m, sigma, random_state=None)
+
+
+def make_2D_samples_gauss(n, m, sigma, random_state=None):
"""return n samples drawn from 2D gaussian N(m,sigma)
Parameters
----------
n : int
- number of bins in the histogram
+ number of samples to make
m : np.array (2,)
mean value of the gaussian distribution
sigma : np.array (2,2)
covariance matrix of the gaussian distribution
-
+ random_state : int, RandomState instance or None, optional (default=None)
+ If int, random_state is the seed used by the random number generator;
+ If RandomState instance, random_state is the random number generator;
+ If None, the random number generator is the RandomState instance used
+ by `np.random`.
Returns
-------
@@ -56,17 +67,25 @@ def get_2D_samples_gauss(n, m, sigma):
n samples drawn from N(m,sigma)
"""
+
+ generator = check_random_state(random_state)
if np.isscalar(sigma):
sigma = np.array([sigma, ])
if len(sigma) > 1:
P = sp.linalg.sqrtm(sigma)
- res = np.random.randn(n, 2).dot(P) + m
+ res = generator.randn(n, 2).dot(P) + m
else:
- res = np.random.randn(n, 2) * np.sqrt(sigma) + m
+ res = generator.randn(n, 2) * np.sqrt(sigma) + m
return res
-def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
+@deprecated()
+def get_2D_samples_gauss(n, m, sigma, random_state=None):
+ """ Deprecated see make_2D_samples_gauss """
+ return make_2D_samples_gauss(n, m, sigma, random_state=None)
+
+
+def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
""" dataset generation for classification problems
Parameters
@@ -78,7 +97,11 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
number of training samples
nz : float
noise level (>0)
-
+ random_state : int, RandomState instance or None, optional (default=None)
+ If int, random_state is the seed used by the random number generator;
+ If RandomState instance, random_state is the random number generator;
+ If None, the random number generator is the RandomState instance used
+ by `np.random`.
Returns
-------
@@ -88,6 +111,9 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
labels of the samples
"""
+
+ generator = check_random_state(random_state)
+
if dataset.lower() == '3gauss':
y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
x = np.zeros((n, 2))
@@ -99,8 +125,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
x[y == 3, 0] = 1.
x[y == 3, 1] = 0
- x[y != 3, :] += 1.5 * nz * np.random.randn(sum(y != 3), 2)
- x[y == 3, :] += 2 * nz * np.random.randn(sum(y == 3), 2)
+ x[y != 3, :] += 1.5 * nz * generator.randn(sum(y != 3), 2)
+ x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2)
elif dataset.lower() == '3gauss2':
y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
@@ -114,8 +140,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
x[y == 3, 0] = 2.
x[y == 3, 1] = 0
- x[y != 3, :] += nz * np.random.randn(sum(y != 3), 2)
- x[y == 3, :] += 2 * nz * np.random.randn(sum(y == 3), 2)
+ x[y != 3, :] += nz * generator.randn(sum(y != 3), 2)
+ x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2)
elif dataset.lower() == 'gaussrot':
rot = np.array(
@@ -127,8 +153,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
n2 = np.sum(y == 2)
x = np.zeros((n, 2))
- x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz)
- x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz)
+ x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz, random_state=generator)
+ x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz, random_state=generator)
x = x.dot(rot)
@@ -138,3 +164,9 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
print("unknown dataset")
return x, y.astype(int)
+
+
+@deprecated()
+def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
+ """ Deprecated see make_data_classif """
+ return make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs)
diff --git a/ot/gromov.py b/ot/gromov.py
index 65b2e29..0278e99 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -2,7 +2,6 @@
# -*- coding: utf-8 -*-
"""
Gromov-Wasserstein transport method
-===================================
"""
diff --git a/ot/utils.py b/ot/utils.py
index 17983f2..7dac283 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -225,6 +225,26 @@ def check_params(**kwargs):
return check
+def check_random_state(seed):
+ """Turn seed into a np.random.RandomState instance
+ Parameters
+ ----------
+ seed : None | int | instance of RandomState
+ If seed is None, return the RandomState singleton used by np.random.
+ If seed is an int, return a new RandomState instance seeded with seed.
+ If seed is already a RandomState instance, return it.
+ Otherwise raise ValueError.
+ """
+ if seed is None or seed is np.random:
+ return np.random.mtrand._rand
+ if isinstance(seed, (int, np.integer)):
+ return np.random.RandomState(seed)
+ if isinstance(seed, np.random.RandomState):
+ return seed
+ raise ValueError('{} cannot be used to seed a numpy.random.RandomState'
+ ' instance'.format(seed))
+
+
class deprecated(object):
"""Decorator to mark a function or class as deprecated.