Diffstat (limited to 'ot/utils.py')
-rw-r--r--  ot/utils.py  269
1 files changed, 163 insertions, 106 deletions
diff --git a/ot/utils.py b/ot/utils.py
index f9911a1..c878563 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -7,7 +7,6 @@ Various useful functions
#
# License: MIT License
-import multiprocessing
from functools import reduce
import time
@@ -15,67 +14,127 @@ import numpy as np
from scipy.spatial.distance import cdist
import sys
import warnings
-try:
- from inspect import signature
-except ImportError:
- from .externals.funcsigs import signature
+from inspect import signature
+from .backend import get_backend
__time_tic_toc = time.time()
def tic():
- """ Python implementation of Matlab tic() function """
+ r""" Python implementation of Matlab tic() function """
global __time_tic_toc
__time_tic_toc = time.time()
def toc(message='Elapsed time : {} s'):
- """ Python implementation of Matlab toc() function """
+ r""" Python implementation of Matlab toc() function """
t = time.time()
print(message.format(t - __time_tic_toc))
return t - __time_tic_toc
def toq():
- """ Python implementation of Julia toc() function """
+ r""" Python implementation of Julia toc() function """
t = time.time()
return t - __time_tic_toc
def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
- """Compute kernel matrix"""
+ r"""Compute kernel matrix"""
+
+ nx = get_backend(x1, x2)
+
if method.lower() in ['gaussian', 'gauss', 'rbf']:
- K = np.exp(-dist(x1, x2) / (2 * sigma**2))
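+ # Gaussian (RBF) kernel: K_ij = exp(-||x1_i - x2_j||^2 / (2 * sigma^2));
+ # dist() returns squared Euclidean distances by default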
+ K = nx.exp(-dist(x1, x2) / (2 * sigma**2))
return K
def laplacian(x):
- """Compute Laplacian matrix"""
+ r"""Compute Laplacian matrix"""
L = np.diag(np.sum(x, axis=0)) - x
return L
-def unif(n):
- """ return a uniform histogram of length n (simplex)
+def list_to_array(*lst):
+ r""" Convert a list if in numpy format """
+ if len(lst) > 1:
+ return [np.array(a) if isinstance(a, list) else a for a in lst]
+ else:
+ return np.array(lst[0]) if isinstance(lst[0], list) else lst[0]
+
+
+def proj_simplex(v, z=1):
+ r"""Compute the closest point (orthogonal projection) on the
+ generalized `(n-1)`-simplex of a vector :math:`\mathbf{v}` w.r.t. the Euclidean
+ distance, thus solving:
+
+ .. math::
+ \mathcal{P}(\mathbf{v}) \in \mathop{\arg \min}_\gamma \| \gamma - \mathbf{v} \|_2
+
+ s.t. \ \gamma^T \mathbf{1} = z
+
+ \gamma \geq 0
+
+ If :math:`\mathbf{v}` is a 2d array, compute all the projections wrt. axis 0
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Parameters
----------
+ v : array-like, shape (n, d)
+ z : int, optional
+ 'size' of the simplex (each vector sums to z, 1 by default)
+
+ Returns
+ -------
+ h : ndarray, shape (`n`, `d`)
+ Array of projections on the simplex
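+
+ Examples
+ --------
+ A small illustrative call, projecting a 1d vector onto the unit simplex
+ (default z=1):
+
+ >>> import numpy as np
+ >>> proj_simplex(np.array([0.5, 0.1]))  # doctest: +SKIP
+ array([0.7, 0.3])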
+ """
+ nx = get_backend(v)
+ n = v.shape[0]
+ if v.ndim == 1:
+ d1 = 1
+ v = v[:, None]
+ else:
+ d1 = 0
+ d = v.shape[1]
+
+ # sort u in ascending order
+ u = nx.sort(v, axis=0)
+ # take the descending order
+ u = nx.flip(u, 0)
+ cssv = nx.cumsum(u, axis=0) - z
+ ind = nx.arange(n, type_as=v)[:, None] + 1
+ cond = u - cssv / ind > 0
+ rho = nx.sum(cond, 0)
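+ # rho is the number of strictly positive entries of the projection and theta
+ # the uniform shift such that the thresholded vector sums to z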
+ theta = cssv[rho - 1, nx.arange(d)] / rho
+ w = nx.maximum(v - theta[None, :], nx.zeros(v.shape, type_as=v))
+ if d1:
+ return w[:, 0]
+ else:
+ return w
+
+
+def unif(n):
+ r"""
+ Return a uniform histogram of length `n` (simplex).
+ Parameters
+ ----------
n : int
number of bins in the histogram
Returns
-------
- h : np.array (n,)
- histogram of length n such that h_i=1/n for all i
-
-
+ h : np.array (`n`,)
+ histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}`
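+
+ Examples
+ --------
+ A quick check that the bins are uniform and sum to one:
+
+ >>> unif(4)  # doctest: +SKIP
+ array([0.25, 0.25, 0.25, 0.25])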
"""
return np.ones((n,)) / n
def clean_zeros(a, b, M):
- """ Remove all components with zeros weights in a and b
+ r""" Remove all components with zeros weights in :math:`\mathbf{a}` and :math:`\mathbf{b}`
"""
M2 = M[a > 0, :][:, b > 0].copy()  # copy forces a C-style matrix (for emd)
a2 = a[a > 0]
@@ -84,55 +143,71 @@ def clean_zeros(a, b, M):
def euclidean_distances(X, Y, squared=False):
- """
- Considering the rows of X (and Y=X) as vectors, compute the
+ r"""
+ Considering the rows of :math:`\mathbf{X}` (and :math:`\mathbf{Y} = \mathbf{X}`) as vectors, compute the
distance matrix between each pair of vectors.
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
Parameters
----------
- X : {array-like}, shape (n_samples_1, n_features)
- Y : {array-like}, shape (n_samples_2, n_features)
+ X : array-like, shape (n_samples_1, n_features)
+ Y : array-like, shape (n_samples_2, n_features)
squared : boolean, optional
Return squared Euclidean distances.
+
Returns
-------
- distances : {array}, shape (n_samples_1, n_samples_2)
+ distances : array-like, shape (`n_samples_1`, `n_samples_2`)
"""
- XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis]
- YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :]
- distances = np.dot(X, Y.T)
- distances *= -2
- distances += XX
- distances += YY
- np.maximum(distances, 0, out=distances)
+
+ nx = get_backend(X, Y)
+
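+ # squared Euclidean norms of the rows of X and of Y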
+ a2 = nx.einsum('ij,ij->i', X, X)
+ b2 = nx.einsum('ij,ij->i', Y, Y)
+
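+ # pairwise squared distances via the expansion ||x - y||^2 = ||x||^2 - 2 <x, y> + ||y||^2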
+ c = -2 * nx.dot(X, Y.T)
+ c += a2[:, None]
+ c += b2[None, :]
+
+ c = nx.maximum(c, 0)
+
+ if not squared:
+ c = nx.sqrt(c)
+
if X is Y:
- # Ensure that distances between vectors and themselves are set to 0.0.
- # This may not be the case due to floating point rounding errors.
- distances.flat[::distances.shape[0] + 1] = 0.0
- return distances if squared else np.sqrt(distances, out=distances)
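+ # zero the diagonal explicitly: floating point rounding may otherwise leave
+ # small nonzero distances between a vector and itself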
+ c = c * (1 - nx.eye(X.shape[0], type_as=c))
+
+ return c
+
+def dist(x1, x2=None, metric='sqeuclidean', p=2):
+ r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`
-def dist(x1, x2=None, metric='sqeuclidean'):
- """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Parameters
----------
- x1 : ndarray, shape (n1,d)
- matrix with n1 samples of size d
- x2 : array, shape (n2,d), optional
- matrix with n2 samples of size d (if None then x2=x1)
+ x1 : array-like, shape (n1,d)
+ matrix with `n1` samples of size `d`
+ x2 : array-like, shape (n2,d), optional
+ matrix with `n2` samples of size `d` (if None then :math:`\mathbf{x_2} = \mathbf{x_1}`)
metric : str | callable, optional
- Name of the metric to be computed (full list in the doc of scipy), If a string,
- the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock',
- 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
- 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
+ 'sqeuclidean' or 'euclidean' on all backends. With the numpy backend, the function
+ also accepts any metric supported by scipy.spatial.distance.cdist: 'braycurtis',
+ 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice',
+ 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis',
+ 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'.
Returns
-------
- M : np.array (n1,n2)
+ M : array-like, shape (`n1`, `n2`)
distance matrix computed with given metric
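+
+ Examples
+ --------
+ An illustrative call with the default squared Euclidean metric:
+
+ >>> import numpy as np
+ >>> x = np.array([[0., 0.], [1., 0.]])
+ >>> dist(x)  # doctest: +SKIP
+ array([[0., 1.],
+        [1., 0.]])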
"""
@@ -140,11 +215,17 @@ def dist(x1, x2=None, metric='sqeuclidean'):
x2 = x1
if metric == "sqeuclidean":
return euclidean_distances(x1, x2, squared=True)
- return cdist(x1, x2, metric=metric)
+ elif metric == "euclidean":
+ return euclidean_distances(x1, x2, squared=False)
+ else:
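+ # other metrics rely on scipy.spatial.distance.cdist and are only available
+ # with the numpy backend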
+ if not get_backend(x1, x2).__name__ == 'numpy':
+ raise NotImplementedError()
+ else:
+ return cdist(x1, x2, metric=metric, p=p)
def dist0(n, method='lin_square'):
- """Compute standard cost matrices of size (n, n) for OT problems
+ r"""Compute standard cost matrices of size (`n`, `n`) for OT problems
Parameters
----------
@@ -153,11 +234,11 @@ def dist0(n, method='lin_square'):
method : str, optional
Type of loss matrix chosen from:
- * 'lin_square' : linear sampling between 0 and n-1, quadratic loss
+ * 'lin_square' : linear sampling between 0 and `n-1`, quadratic loss
Returns
-------
- M : ndarray, shape (n1,n2)
+ M : ndarray, shape (`n1`, `n2`)
Distance matrix computed with given metric.
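+
+ Examples
+ --------
+ With 'lin_square', entry (i, j) is the quadratic loss (i - j)**2
+ (output shown for illustration):
+
+ >>> dist0(3)  # doctest: +SKIP
+ array([[0., 1., 4.],
+        [1., 0., 1.],
+        [4., 1., 0.]])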
"""
res = 0
@@ -168,7 +249,7 @@ def dist0(n, method='lin_square'):
def cost_normalization(C, norm=None):
- """ Apply normalization to the loss matrix
+ r""" Apply normalization to the loss matrix
Parameters
----------
@@ -180,7 +261,7 @@ def cost_normalization(C, norm=None):
Returns
-------
- C : ndarray, shape (n1, n2)
+ C : ndarray, shape (`n1`, `n2`)
The input cost matrix normalized according to given norm.
"""
@@ -202,23 +283,23 @@ def cost_normalization(C, norm=None):
def dots(*args):
- """ dots function for multiple matrix multiply """
+ r""" dots function for multiple matrix multiply """
return reduce(np.dot, args)
def label_normalization(y, start=0):
- """ Transform labels to start at a given value
+ r""" Transform labels to start at a given value
Parameters
----------
y : array-like, shape (n, )
The vector of labels to be normalized.
start : int
- Desired value for the smallest label in y (default=0)
+ Desired value for the smallest label in :math:`\mathbf{y}` (default=0)
Returns
-------
- y : array-like, shape (n1, )
+ y : array-like, shape (`n1`, )
The input vector of labels normalized according to given start value.
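+
+ Examples
+ --------
+ Labels are shifted so that the smallest one equals `start`
+ (output shown for illustration):
+
+ >>> label_normalization(np.array([3, 3, 5]))  # doctest: +SKIP
+ array([0, 0, 2])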
"""
@@ -228,42 +309,15 @@ def label_normalization(y, start=0):
return y
-def fun(f, q_in, q_out):
- """ Utility function for parmap with no serializing problems """
- while True:
- i, x = q_in.get()
- if i is None:
- break
- q_out.put((i, f(x)))
-
-
-def parmap(f, X, nprocs=multiprocessing.cpu_count()):
- """ paralell map for multiprocessing (only map on windows)"""
-
- if not sys.platform.endswith('win32'):
-
- q_in = multiprocessing.Queue(1)
- q_out = multiprocessing.Queue()
-
- proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out))
- for _ in range(nprocs)]
- for p in proc:
- p.daemon = True
- p.start()
-
- sent = [q_in.put((i, x)) for i, x in enumerate(X)]
- [q_in.put((None, None)) for _ in range(nprocs)]
- res = [q_out.get() for _ in range(len(sent))]
-
- [p.join() for p in proc]
-
- return [x for i, x in sorted(res)]
- else:
- return list(map(f, X))
+def parmap(f, X, nprocs="default"):
+ r""" parallel map for multiprocessing.
+ The function has been deprecated and only performs a regular map.
+ """
+ return list(map(f, X))
def check_params(**kwargs):
- """check_params: check whether some parameters are missing
+ r"""check_params: check whether some parameters are missing
"""
missing_params = []
@@ -284,14 +338,14 @@ def check_params(**kwargs):
def check_random_state(seed):
- """Turn seed into a np.random.RandomState instance
+ r"""Turn `seed` into a np.random.RandomState instance
Parameters
----------
seed : None | int | instance of RandomState
- If seed is None, return the RandomState singleton used by np.random.
- If seed is an int, return a new RandomState instance seeded with seed.
- If seed is already a RandomState instance, return it.
+ If `seed` is None, return the RandomState singleton used by np.random.
+ If `seed` is an int, return a new RandomState instance seeded with `seed`.
+ If `seed` is already a RandomState instance, return it.
Otherwise raise ValueError.
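+
+ Examples
+ --------
+ A minimal usage sketch:
+
+ >>> rng = check_random_state(0)  # doctest: +SKIP
+ >>> rng.rand(2)  # doctest: +SKIP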
"""
if seed is None or seed is np.random:
@@ -305,18 +359,21 @@ def check_random_state(seed):
class deprecated(object):
- """Decorator to mark a function or class as deprecated.
+ r"""Decorator to mark a function or class as deprecated.
deprecated class from scikit-learn package
https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py
Issue a warning when the function is called/the class is instantiated and
adds a warning to the docstring.
The optional extra argument will be appended to the deprecation message
- and the docstring. Note: to use this with the default value for extra, put
- in an empty of parentheses:
- >>> from ot.deprecation import deprecated # doctest: +SKIP
- >>> @deprecated() # doctest: +SKIP
- ... def some_function(): pass # doctest: +SKIP
+ and the docstring.
+
+ .. note::
+ To use this with the default value for extra, use empty parentheses:
+
+ >>> from ot.deprecation import deprecated # doctest: +SKIP
+ >>> @deprecated() # doctest: +SKIP
+ ... def some_function(): pass # doctest: +SKIP
Parameters
----------
@@ -331,7 +388,7 @@ class deprecated(object):
self.extra = extra
def __call__(self, obj):
- """Call method
+ r"""Call method
Parameters
----------
obj : object
@@ -362,7 +419,7 @@ class deprecated(object):
return cls
def _decorate_fun(self, fun):
- """Decorate function fun"""
+ r"""Decorate function fun"""
msg = "Function %s is deprecated" % fun.__name__
if self.extra:
@@ -388,7 +445,7 @@ class deprecated(object):
def _is_deprecated(func):
- """Helper to check if func is wraped by our deprecated decorator"""
+ r"""Helper to check if func is wraped by our deprecated decorator"""
if sys.version_info < (3, 5):
raise NotImplementedError("This is only available for python3.5 "
"or above")
@@ -402,7 +459,7 @@ def _is_deprecated(func):
class BaseEstimator(object):
- """Base class for most objects in POT
+ r"""Base class for most objects in POT
Code adapted from sklearn BaseEstimator class
@@ -415,7 +472,7 @@ class BaseEstimator(object):
@classmethod
def _get_param_names(cls):
- """Get parameter names for the estimator"""
+ r"""Get parameter names for the estimator"""
# fetch the constructor or the original constructor before
# deprecation wrapping if any
@@ -442,7 +499,7 @@ class BaseEstimator(object):
return sorted([p.name for p in parameters])
def get_params(self, deep=True):
- """Get parameters for this estimator.
+ r"""Get parameters for this estimator.
Parameters
----------
@@ -479,7 +536,7 @@ class BaseEstimator(object):
return out
def set_params(self, **params):
- """Set the parameters of this estimator.
+ r"""Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects
(such as pipelines). The latter have parameters of the form
@@ -519,7 +576,7 @@ class BaseEstimator(object):
class UndefinedParameter(Exception):
- """
+ r"""
Aim at raising an Exception when an undefined parameter is called
"""