diff options
author | ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> | 2021-11-03 17:29:16 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-03 17:29:16 +0100 |
commit | 9c6ac880d426b7577918b0c77bd74b3b01930ef6 (patch) | |
tree | 93b0899a0378a6fe8f063800091252d2c6ad9801 /ot/utils.py | |
parent | e1b67c641da3b3e497db6811af2c200022b10302 (diff) |
[MRG] Docs updates (#298)
* bregman docs
* sliced docs
* docs partial
* unbalanced docs
* stochastic docs
* plot docs
* datasets docs
* utils docs
* dr docs
* dr docs corrected
* smooth docs
* docs da
* pep8
* docs gromov
* more space after min and argmin
* docs lp
* bregman docs
* bregman docs mistake corrected
* pep8
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 118 |
1 files changed, 60 insertions, 58 deletions
diff --git a/ot/utils.py b/ot/utils.py index 0608aee..c878563 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -21,26 +21,26 @@ __time_tic_toc = time.time() def tic(): - """ Python implementation of Matlab tic() function """ + r""" Python implementation of Matlab tic() function """ global __time_tic_toc __time_tic_toc = time.time() def toc(message='Elapsed time : {} s'): - """ Python implementation of Matlab toc() function """ + r""" Python implementation of Matlab toc() function """ t = time.time() print(message.format(t - __time_tic_toc)) return t - __time_tic_toc def toq(): - """ Python implementation of Julia toc() function """ + r""" Python implementation of Julia toc() function """ t = time.time() return t - __time_tic_toc def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): - """Compute kernel matrix""" + r"""Compute kernel matrix""" nx = get_backend(x1, x2) @@ -50,13 +50,13 @@ def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): def laplacian(x): - """Compute Laplacian matrix""" + r"""Compute Laplacian matrix""" L = np.diag(np.sum(x, axis=0)) - x return L def list_to_array(*lst): - """ Convert a list if in numpy format """ + r""" Convert a list if in numpy format """ if len(lst) > 1: return [np.array(a) if isinstance(a, list) else a for a in lst] else: @@ -64,17 +64,18 @@ def list_to_array(*lst): def proj_simplex(v, z=1): - r""" compute the closest point (orthogonal projection) on the - generalized (n-1)-simplex of a vector v wrt. to the Euclidean + r"""Compute the closest point (orthogonal projection) on the + generalized `(n-1)`-simplex of a vector :math:`\mathbf{v}` wrt. to the Euclidean distance, thus solving: + .. math:: - \mathcal{P}(w) \in arg\min_\gamma || \gamma - v ||_2 + \mathcal{P}(w) \in \mathop{\arg \min}_\gamma \| \gamma - \mathbf{v} \|_2 - s.t. \gamma^T 1= z + s.t. \ \gamma^T \mathbf{1} = z - \gamma\geq 0 + \gamma \geq 0 - If v is a 2d array, compute all the projections wrt. axis 0 + If :math:`\mathbf{v}` is a 2d array, compute all the projections wrt. axis 0 .. note:: This function is backend-compatible and will work on arrays from all compatible backends. @@ -87,7 +88,7 @@ def proj_simplex(v, z=1): Returns ------- - h : ndarray, shape (n,d) + h : ndarray, shape (`n`, `d`) Array of projections on the simplex """ nx = get_backend(v) @@ -116,26 +117,24 @@ def proj_simplex(v, z=1): def unif(n): - """ return a uniform histogram of length n (simplex) + r""" + Return a uniform histogram of length `n` (simplex). Parameters ---------- - n : int number of bins in the histogram Returns ------- - h : np.array (n,) - histogram of length n such that h_i=1/n for all i - - + h : np.array (`n`,) + histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}` """ return np.ones((n,)) / n def clean_zeros(a, b, M): - """ Remove all components with zeros weights in a and b + r""" Remove all components with zeros weights in :math:`\mathbf{a}` and :math:`\mathbf{b}` """ M2 = M[a > 0, :][:, b > 0].copy() # copy force c style matrix (froemd) a2 = a[a > 0] @@ -144,8 +143,8 @@ def clean_zeros(a, b, M): def euclidean_distances(X, Y, squared=False): - """ - Considering the rows of X (and Y=X) as vectors, compute the + r""" + Considering the rows of :math:`\mathbf{X}` (and :math:`\mathbf{Y} = \mathbf{X}`) as vectors, compute the distance matrix between each pair of vectors. .. note:: This function is backend-compatible and will work on arrays @@ -153,14 +152,14 @@ def euclidean_distances(X, Y, squared=False): Parameters ---------- - X : {array-like}, shape (n_samples_1, n_features) - Y : {array-like}, shape (n_samples_2, n_features) + X : array-like, shape (n_samples_1, n_features) + Y : array-like, shape (n_samples_2, n_features) squared : boolean, optional Return squared Euclidean distances. Returns ------- - distances : {array}, shape (n_samples_1, n_samples_2) + distances : array-like, shape (`n_samples_1`, `n_samples_2`) """ nx = get_backend(X, Y) @@ -184,7 +183,7 @@ def euclidean_distances(X, Y, squared=False): def dist(x1, x2=None, metric='sqeuclidean', p=2): - """Compute distance between samples in x1 and x2 + r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` .. note:: This function is backend-compatible and will work on arrays from all compatible backends. @@ -193,9 +192,9 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2): ---------- x1 : array-like, shape (n1,d) - matrix with n1 samples of size d + matrix with `n1` samples of size `d` x2 : array-like, shape (n2,d), optional - matrix with n2 samples of size d (if None then x2=x1) + matrix with `n2` samples of size `d` (if None then :math:`\mathbf{x_2} = \mathbf{x_1}`) metric : str | callable, optional 'sqeuclidean' or 'euclidean' on all backends. On numpy the function also accepts from the scipy.spatial.distance.cdist function : 'braycurtis', @@ -208,7 +207,7 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2): Returns ------- - M : array-like, shape (n1, n2) + M : array-like, shape (`n1`, `n2`) distance matrix computed with given metric """ @@ -226,7 +225,7 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2): def dist0(n, method='lin_square'): - """Compute standard cost matrices of size (n, n) for OT problems + r"""Compute standard cost matrices of size (`n`, `n`) for OT problems Parameters ---------- @@ -235,11 +234,11 @@ def dist0(n, method='lin_square'): method : str, optional Type of loss matrix chosen from: - * 'lin_square' : linear sampling between 0 and n-1, quadratic loss + * 'lin_square' : linear sampling between 0 and `n-1`, quadratic loss Returns ------- - M : ndarray, shape (n1,n2) + M : ndarray, shape (`n1`, `n2`) Distance matrix computed with given metric. """ res = 0 @@ -250,7 +249,7 @@ def dist0(n, method='lin_square'): def cost_normalization(C, norm=None): - """ Apply normalization to the loss matrix + r""" Apply normalization to the loss matrix Parameters ---------- @@ -262,7 +261,7 @@ def cost_normalization(C, norm=None): Returns ------- - C : ndarray, shape (n1, n2) + C : ndarray, shape (`n1`, `n2`) The input cost matrix normalized according to given norm. """ @@ -284,23 +283,23 @@ def cost_normalization(C, norm=None): def dots(*args): - """ dots function for multiple matrix multiply """ + r""" dots function for multiple matrix multiply """ return reduce(np.dot, args) def label_normalization(y, start=0): - """ Transform labels to start at a given value + r""" Transform labels to start at a given value Parameters ---------- y : array-like, shape (n, ) The vector of labels to be normalized. start : int - Desired value for the smallest label in y (default=0) + Desired value for the smallest label in :math:`\mathbf{y}` (default=0) Returns ------- - y : array-like, shape (n1, ) + y : array-like, shape (`n1`, ) The input vector of labels normalized according to given start value. """ @@ -311,14 +310,14 @@ def label_normalization(y, start=0): def parmap(f, X, nprocs="default"): - """ paralell map for multiprocessing. + r""" parallel map for multiprocessing. The function has been deprecated and only performs a regular map. """ return list(map(f, X)) def check_params(**kwargs): - """check_params: check whether some parameters are missing + r"""check_params: check whether some parameters are missing """ missing_params = [] @@ -339,14 +338,14 @@ def check_params(**kwargs): def check_random_state(seed): - """Turn seed into a np.random.RandomState instance + r"""Turn `seed` into a np.random.RandomState instance Parameters ---------- seed : None | int | instance of RandomState - If seed is None, return the RandomState singleton used by np.random. - If seed is an int, return a new RandomState instance seeded with seed. - If seed is already a RandomState instance, return it. + If `seed` is None, return the RandomState singleton used by np.random. + If `seed` is an int, return a new RandomState instance seeded with `seed`. + If `seed` is already a RandomState instance, return it. Otherwise raise ValueError. """ if seed is None or seed is np.random: @@ -360,18 +359,21 @@ def check_random_state(seed): class deprecated(object): - """Decorator to mark a function or class as deprecated. + r"""Decorator to mark a function or class as deprecated. deprecated class from scikit-learn package https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py Issue a warning when the function is called/the class is instantiated and adds a warning to the docstring. The optional extra argument will be appended to the deprecation message - and the docstring. Note: to use this with the default value for extra, put - in an empty of parentheses: - >>> from ot.deprecation import deprecated # doctest: +SKIP - >>> @deprecated() # doctest: +SKIP - ... def some_function(): pass # doctest: +SKIP + and the docstring. + + .. note:: + To use this with the default value for extra, use empty parentheses: + + >>> from ot.deprecation import deprecated # doctest: +SKIP + >>> @deprecated() # doctest: +SKIP + ... def some_function(): pass # doctest: +SKIP Parameters ---------- @@ -386,7 +388,7 @@ class deprecated(object): self.extra = extra def __call__(self, obj): - """Call method + r"""Call method Parameters ---------- obj : object @@ -417,7 +419,7 @@ class deprecated(object): return cls def _decorate_fun(self, fun): - """Decorate function fun""" + r"""Decorate function fun""" msg = "Function %s is deprecated" % fun.__name__ if self.extra: @@ -443,7 +445,7 @@ class deprecated(object): def _is_deprecated(func): - """Helper to check if func is wraped by our deprecated decorator""" + r"""Helper to check if func is wraped by our deprecated decorator""" if sys.version_info < (3, 5): raise NotImplementedError("This is only available for python3.5 " "or above") @@ -457,7 +459,7 @@ def _is_deprecated(func): class BaseEstimator(object): - """Base class for most objects in POT + r"""Base class for most objects in POT Code adapted from sklearn BaseEstimator class @@ -470,7 +472,7 @@ class BaseEstimator(object): @classmethod def _get_param_names(cls): - """Get parameter names for the estimator""" + r"""Get parameter names for the estimator""" # fetch the constructor or the original constructor before # deprecation wrapping if any @@ -497,7 +499,7 @@ class BaseEstimator(object): return sorted([p.name for p in parameters]) def get_params(self, deep=True): - """Get parameters for this estimator. + r"""Get parameters for this estimator. Parameters ---------- @@ -534,7 +536,7 @@ class BaseEstimator(object): return out def set_params(self, **params): - """Set the parameters of this estimator. + r"""Set the parameters of this estimator. The method works on simple estimators as well as on nested objects (such as pipelines). The latter have parameters of the form @@ -574,7 +576,7 @@ class BaseEstimator(object): class UndefinedParameter(Exception): - """ + r""" Aim at raising an Exception when a undefined parameter is called """ |