summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-11-03 17:29:16 +0100
committerGitHub <noreply@github.com>2021-11-03 17:29:16 +0100
commit9c6ac880d426b7577918b0c77bd74b3b01930ef6 (patch)
tree93b0899a0378a6fe8f063800091252d2c6ad9801 /ot/utils.py
parente1b67c641da3b3e497db6811af2c200022b10302 (diff)
[MRG] Docs updates (#298)
* bregman docs * sliced docs * docs partial * unbalanced docs * stochastic docs * plot docs * datasets docs * utils docs * dr docs * dr docs corrected * smooth docs * docs da * pep8 * docs gromov * more space after min and argmin * docs lp * bregman docs * bregman docs mistake corrected * pep8 Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py118
1 files changed, 60 insertions, 58 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 0608aee..c878563 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -21,26 +21,26 @@ __time_tic_toc = time.time()
def tic():
- """ Python implementation of Matlab tic() function """
+ r""" Python implementation of Matlab tic() function """
global __time_tic_toc
__time_tic_toc = time.time()
def toc(message='Elapsed time : {} s'):
- """ Python implementation of Matlab toc() function """
+ r""" Python implementation of Matlab toc() function """
t = time.time()
print(message.format(t - __time_tic_toc))
return t - __time_tic_toc
def toq():
- """ Python implementation of Julia toc() function """
+ r""" Python implementation of Julia toc() function """
t = time.time()
return t - __time_tic_toc
def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
- """Compute kernel matrix"""
+ r"""Compute kernel matrix"""
nx = get_backend(x1, x2)
@@ -50,13 +50,13 @@ def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
def laplacian(x):
- """Compute Laplacian matrix"""
+ r"""Compute Laplacian matrix"""
L = np.diag(np.sum(x, axis=0)) - x
return L
def list_to_array(*lst):
- """ Convert a list if in numpy format """
+ r""" Convert a list if in numpy format """
if len(lst) > 1:
return [np.array(a) if isinstance(a, list) else a for a in lst]
else:
@@ -64,17 +64,18 @@ def list_to_array(*lst):
def proj_simplex(v, z=1):
- r""" compute the closest point (orthogonal projection) on the
- generalized (n-1)-simplex of a vector v wrt. to the Euclidean
+ r"""Compute the closest point (orthogonal projection) on the
+ generalized `(n-1)`-simplex of a vector :math:`\mathbf{v}` wrt. to the Euclidean
distance, thus solving:
+
.. math::
- \mathcal{P}(w) \in arg\min_\gamma || \gamma - v ||_2
+ \mathcal{P}(w) \in \mathop{\arg \min}_\gamma \| \gamma - \mathbf{v} \|_2
- s.t. \gamma^T 1= z
+ s.t. \ \gamma^T \mathbf{1} = z
- \gamma\geq 0
+ \gamma \geq 0
- If v is a 2d array, compute all the projections wrt. axis 0
+ If :math:`\mathbf{v}` is a 2d array, compute all the projections wrt. axis 0
.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
@@ -87,7 +88,7 @@ def proj_simplex(v, z=1):
Returns
-------
- h : ndarray, shape (n,d)
+ h : ndarray, shape (`n`, `d`)
Array of projections on the simplex
"""
nx = get_backend(v)
@@ -116,26 +117,24 @@ def proj_simplex(v, z=1):
def unif(n):
- """ return a uniform histogram of length n (simplex)
+ r"""
+ Return a uniform histogram of length `n` (simplex).
Parameters
----------
-
n : int
number of bins in the histogram
Returns
-------
- h : np.array (n,)
- histogram of length n such that h_i=1/n for all i
-
-
+ h : np.array (`n`,)
+ histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}`
"""
return np.ones((n,)) / n
def clean_zeros(a, b, M):
- """ Remove all components with zeros weights in a and b
+ r""" Remove all components with zeros weights in :math:`\mathbf{a}` and :math:`\mathbf{b}`
"""
M2 = M[a > 0, :][:, b > 0].copy() # copy force c style matrix (froemd)
a2 = a[a > 0]
@@ -144,8 +143,8 @@ def clean_zeros(a, b, M):
def euclidean_distances(X, Y, squared=False):
- """
- Considering the rows of X (and Y=X) as vectors, compute the
+ r"""
+ Considering the rows of :math:`\mathbf{X}` (and :math:`\mathbf{Y} = \mathbf{X}`) as vectors, compute the
distance matrix between each pair of vectors.
.. note:: This function is backend-compatible and will work on arrays
@@ -153,14 +152,14 @@ def euclidean_distances(X, Y, squared=False):
Parameters
----------
- X : {array-like}, shape (n_samples_1, n_features)
- Y : {array-like}, shape (n_samples_2, n_features)
+ X : array-like, shape (n_samples_1, n_features)
+ Y : array-like, shape (n_samples_2, n_features)
squared : boolean, optional
Return squared Euclidean distances.
Returns
-------
- distances : {array}, shape (n_samples_1, n_samples_2)
+ distances : array-like, shape (`n_samples_1`, `n_samples_2`)
"""
nx = get_backend(X, Y)
@@ -184,7 +183,7 @@ def euclidean_distances(X, Y, squared=False):
def dist(x1, x2=None, metric='sqeuclidean', p=2):
- """Compute distance between samples in x1 and x2
+ r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`
.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
@@ -193,9 +192,9 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2):
----------
x1 : array-like, shape (n1,d)
- matrix with n1 samples of size d
+ matrix with `n1` samples of size `d`
x2 : array-like, shape (n2,d), optional
- matrix with n2 samples of size d (if None then x2=x1)
+ matrix with `n2` samples of size `d` (if None then :math:`\mathbf{x_2} = \mathbf{x_1}`)
metric : str | callable, optional
'sqeuclidean' or 'euclidean' on all backends. On numpy the function also
accepts from the scipy.spatial.distance.cdist function : 'braycurtis',
@@ -208,7 +207,7 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2):
Returns
-------
- M : array-like, shape (n1, n2)
+ M : array-like, shape (`n1`, `n2`)
distance matrix computed with given metric
"""
@@ -226,7 +225,7 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2):
def dist0(n, method='lin_square'):
- """Compute standard cost matrices of size (n, n) for OT problems
+ r"""Compute standard cost matrices of size (`n`, `n`) for OT problems
Parameters
----------
@@ -235,11 +234,11 @@ def dist0(n, method='lin_square'):
method : str, optional
Type of loss matrix chosen from:
- * 'lin_square' : linear sampling between 0 and n-1, quadratic loss
+ * 'lin_square' : linear sampling between 0 and `n-1`, quadratic loss
Returns
-------
- M : ndarray, shape (n1,n2)
+ M : ndarray, shape (`n1`, `n2`)
Distance matrix computed with given metric.
"""
res = 0
@@ -250,7 +249,7 @@ def dist0(n, method='lin_square'):
def cost_normalization(C, norm=None):
- """ Apply normalization to the loss matrix
+ r""" Apply normalization to the loss matrix
Parameters
----------
@@ -262,7 +261,7 @@ def cost_normalization(C, norm=None):
Returns
-------
- C : ndarray, shape (n1, n2)
+ C : ndarray, shape (`n1`, `n2`)
The input cost matrix normalized according to given norm.
"""
@@ -284,23 +283,23 @@ def cost_normalization(C, norm=None):
def dots(*args):
- """ dots function for multiple matrix multiply """
+ r""" dots function for multiple matrix multiply """
return reduce(np.dot, args)
def label_normalization(y, start=0):
- """ Transform labels to start at a given value
+ r""" Transform labels to start at a given value
Parameters
----------
y : array-like, shape (n, )
The vector of labels to be normalized.
start : int
- Desired value for the smallest label in y (default=0)
+ Desired value for the smallest label in :math:`\mathbf{y}` (default=0)
Returns
-------
- y : array-like, shape (n1, )
+ y : array-like, shape (`n1`, )
The input vector of labels normalized according to given start value.
"""
@@ -311,14 +310,14 @@ def label_normalization(y, start=0):
def parmap(f, X, nprocs="default"):
- """ paralell map for multiprocessing.
+ r""" parallel map for multiprocessing.
The function has been deprecated and only performs a regular map.
"""
return list(map(f, X))
def check_params(**kwargs):
- """check_params: check whether some parameters are missing
+ r"""check_params: check whether some parameters are missing
"""
missing_params = []
@@ -339,14 +338,14 @@ def check_params(**kwargs):
def check_random_state(seed):
- """Turn seed into a np.random.RandomState instance
+ r"""Turn `seed` into a np.random.RandomState instance
Parameters
----------
seed : None | int | instance of RandomState
- If seed is None, return the RandomState singleton used by np.random.
- If seed is an int, return a new RandomState instance seeded with seed.
- If seed is already a RandomState instance, return it.
+ If `seed` is None, return the RandomState singleton used by np.random.
+ If `seed` is an int, return a new RandomState instance seeded with `seed`.
+ If `seed` is already a RandomState instance, return it.
Otherwise raise ValueError.
"""
if seed is None or seed is np.random:
@@ -360,18 +359,21 @@ def check_random_state(seed):
class deprecated(object):
- """Decorator to mark a function or class as deprecated.
+ r"""Decorator to mark a function or class as deprecated.
deprecated class from scikit-learn package
https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py
Issue a warning when the function is called/the class is instantiated and
adds a warning to the docstring.
The optional extra argument will be appended to the deprecation message
- and the docstring. Note: to use this with the default value for extra, put
- in an empty of parentheses:
- >>> from ot.deprecation import deprecated # doctest: +SKIP
- >>> @deprecated() # doctest: +SKIP
- ... def some_function(): pass # doctest: +SKIP
+ and the docstring.
+
+ .. note::
+ To use this with the default value for extra, use empty parentheses:
+
+ >>> from ot.deprecation import deprecated # doctest: +SKIP
+ >>> @deprecated() # doctest: +SKIP
+ ... def some_function(): pass # doctest: +SKIP
Parameters
----------
@@ -386,7 +388,7 @@ class deprecated(object):
self.extra = extra
def __call__(self, obj):
- """Call method
+ r"""Call method
Parameters
----------
obj : object
@@ -417,7 +419,7 @@ class deprecated(object):
return cls
def _decorate_fun(self, fun):
- """Decorate function fun"""
+ r"""Decorate function fun"""
msg = "Function %s is deprecated" % fun.__name__
if self.extra:
@@ -443,7 +445,7 @@ class deprecated(object):
def _is_deprecated(func):
- """Helper to check if func is wraped by our deprecated decorator"""
+ r"""Helper to check if func is wraped by our deprecated decorator"""
if sys.version_info < (3, 5):
raise NotImplementedError("This is only available for python3.5 "
"or above")
@@ -457,7 +459,7 @@ def _is_deprecated(func):
class BaseEstimator(object):
- """Base class for most objects in POT
+ r"""Base class for most objects in POT
Code adapted from sklearn BaseEstimator class
@@ -470,7 +472,7 @@ class BaseEstimator(object):
@classmethod
def _get_param_names(cls):
- """Get parameter names for the estimator"""
+ r"""Get parameter names for the estimator"""
# fetch the constructor or the original constructor before
# deprecation wrapping if any
@@ -497,7 +499,7 @@ class BaseEstimator(object):
return sorted([p.name for p in parameters])
def get_params(self, deep=True):
- """Get parameters for this estimator.
+ r"""Get parameters for this estimator.
Parameters
----------
@@ -534,7 +536,7 @@ class BaseEstimator(object):
return out
def set_params(self, **params):
- """Set the parameters of this estimator.
+ r"""Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects
(such as pipelines). The latter have parameters of the form
@@ -574,7 +576,7 @@ class BaseEstimator(object):
class UndefinedParameter(Exception):
- """
+ r"""
Aim at raising an Exception when a undefined parameter is called
"""