summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
author Rémi Flamary <remi.flamary@gmail.com> 2021-06-01 10:10:54 +0200
committer GitHub <noreply@github.com> 2021-06-01 10:10:54 +0200
commit 184f8f4f7ac78f1dd7f653496d2753211a4e3426 (patch)
tree 483a7274c91030fd644de49b03a5fad04af9deba /ot/utils.py
parent 1f16614954e2522fbdb1598c5b1f5c3630c68472 (diff)
[MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: 
Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend Co-authored-by: Nicolas Courty <ncourty@irisa.fr> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- ot/utils.py 128
1 files changed, 107 insertions, 21 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 544c569..4dac0c5 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -16,6 +16,7 @@ from scipy.spatial.distance import cdist
import sys
import warnings
from inspect import signature
+from .backend import get_backend
__time_tic_toc = time.time()
@@ -41,8 +42,11 @@ def toq():
def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
"""Compute kernel matrix"""
+
+ nx = get_backend(x1, x2)
+
if method.lower() in ['gaussian', 'gauss', 'rbf']:
- K = np.exp(-dist(x1, x2) / (2 * sigma**2))
+ K = nx.exp(-dist(x1, x2) / (2 * sigma**2))
return K
@@ -52,6 +56,66 @@ def laplacian(x):
return L
+def list_to_array(*lst):
+ """ Convert a list if in numpy format """
+ if len(lst) > 1:
+ return [np.array(a) if isinstance(a, list) else a for a in lst]
+ else:
+ return np.array(lst[0]) if isinstance(lst[0], list) else lst[0]
+
+
+def proj_simplex(v, z=1):
+ r""" compute the closest point (orthogonal projection) on the
+ generalized (n-1)-simplex of a vector v wrt. to the Euclidean
+ distance, thus solving:
+ .. math::
+ \mathcal{P}(w) \in arg\min_\gamma || \gamma - v ||_2
+
+ s.t. \gamma^T 1= z
+
+ \gamma\geq 0
+
+ If v is a 2d array, compute all the projections wrt. axis 0
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
+ Parameters
+ ----------
+ v : {array-like}, shape (n, d)
+ z : int, optional
+ 'size' of the simplex (each vectors sum to z, 1 by default)
+
+ Returns
+ -------
+ h : ndarray, shape (n,d)
+ Array of projections on the simplex
+ """
+ nx = get_backend(v)
+ n = v.shape[0]
+ if v.ndim == 1:
+ d1 = 1
+ v = v[:, None]
+ else:
+ d1 = 0
+ d = v.shape[1]
+
+ # sort u in ascending order
+ u = nx.sort(v, axis=0)
+ # take the descending order
+ u = nx.flip(u, 0)
+ cssv = nx.cumsum(u, axis=0) - z
+ ind = nx.arange(n, type_as=v)[:, None] + 1
+ cond = u - cssv / ind > 0
+ rho = nx.sum(cond, 0)
+ theta = cssv[rho - 1, nx.arange(d)] / rho
+ w = nx.maximum(v - theta[None, :], nx.zeros(v.shape, type_as=v))
+ if d1:
+ return w[:, 0]
+ else:
+ return w
+
+
def unif(n):
""" return a uniform histogram of length n (simplex)
@@ -84,52 +148,68 @@ def euclidean_distances(X, Y, squared=False):
"""
Considering the rows of X (and Y=X) as vectors, compute the
distance matrix between each pair of vectors.
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
Parameters
----------
X : {array-like}, shape (n_samples_1, n_features)
Y : {array-like}, shape (n_samples_2, n_features)
squared : boolean, optional
Return squared Euclidean distances.
+
Returns
-------
distances : {array}, shape (n_samples_1, n_samples_2)
"""
- XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis]
- YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :]
- distances = np.dot(X, Y.T)
- distances *= -2
- distances += XX
- distances += YY
- np.maximum(distances, 0, out=distances)
+
+ nx = get_backend(X, Y)
+
+ a2 = nx.einsum('ij,ij->i', X, X)
+ b2 = nx.einsum('ij,ij->i', Y, Y)
+
+ c = -2 * nx.dot(X, Y.T)
+ c += a2[:, None]
+ c += b2[None, :]
+
+ c = nx.maximum(c, 0)
+
+ if not squared:
+ c = nx.sqrt(c)
+
if X is Y:
- # Ensure that distances between vectors and themselves are set to 0.0.
- # This may not be the case due to floating point rounding errors.
- distances.flat[::distances.shape[0] + 1] = 0.0
- return distances if squared else np.sqrt(distances, out=distances)
+ c = c * (1 - nx.eye(X.shape[0], type_as=c))
+
+ return c
def dist(x1, x2=None, metric='sqeuclidean'):
- """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
+ """Compute distance between samples in x1 and x2
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Parameters
----------
- x1 : ndarray, shape (n1,d)
+ x1 : array-like, shape (n1,d)
matrix with n1 samples of size d
- x2 : array, shape (n2,d), optional
+ x2 : array-like, shape (n2,d), optional
matrix with n2 samples of size d (if None then x2=x1)
metric : str | callable, optional
- Name of the metric to be computed (full list in the doc of scipy), If a string,
- the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock',
- 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
- 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
+ 'sqeuclidean' or 'euclidean' on all backends. On numpy the function also
+ accepts from the scipy.spatial.distance.cdist function : 'braycurtis',
+ 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice',
+ 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis',
+ 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'.
Returns
-------
- M : np.array (n1,n2)
+ M : array-like, shape (n1, n2)
distance matrix computed with given metric
"""
@@ -137,7 +217,13 @@ def dist(x1, x2=None, metric='sqeuclidean'):
x2 = x1
if metric == "sqeuclidean":
return euclidean_distances(x1, x2, squared=True)
- return cdist(x1, x2, metric=metric)
+ elif metric == "euclidean":
+ return euclidean_distances(x1, x2, squared=False)
+ else:
+ if not get_backend(x1, x2).__name__ == 'numpy':
+ raise NotImplementedError()
+ else:
+ return cdist(x1, x2, metric=metric)
def dist0(n, method='lin_square'):