diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-09-24 14:30:44 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-09-24 14:30:44 +0200 |
commit | f45f7a68b221ec5b619b8fd8de797815a1eecf43 (patch) | |
tree | e3cc97cdf0c38e457303ceba32f7dadc20a12139 /ot/gpu/utils.py | |
parent | d258c7d6936410cd78189445a0260d983f7684d6 (diff) |
pep8
Diffstat (limited to 'ot/gpu/utils.py')
-rw-r--r-- | ot/gpu/utils.py | 34 |
1 files changed, 16 insertions, 18 deletions
diff --git a/ot/gpu/utils.py b/ot/gpu/utils.py index 6d0c853..d349a6d 100644 --- a/ot/gpu/utils.py +++ b/ot/gpu/utils.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Utility functions for GPU +Utility functions for GPU """ # Author: Remi Flamary <remi.flamary@unice.fr> @@ -9,9 +9,8 @@ Utility functions for GPU # # License: MIT License -import cupy as np # np used for matrix computation -import cupy as cp # cp used for cupy specific operations - +import cupy as np # np used for matrix computation +import cupy as cp # cp used for cupy specific operations def euclidean_distances(a, b, squared=False, to_numpy=True): @@ -34,16 +33,16 @@ def euclidean_distances(a, b, squared=False, to_numpy=True): c : (n x m) np.ndarray or cupy.ndarray pairwise euclidean distance distance matrix """ - + a, b = to_gpu(a, b) - - a2=np.sum(np.square(a),1) - b2=np.sum(np.square(b),1) - - c=-2*np.dot(a,b.T) - c+=a2[:,None] - c+=b2[None,:] - + + a2 = np.sum(np.square(a), 1) + b2 = np.sum(np.square(b), 1) + + c = -2 * np.dot(a, b.T) + c += a2[:, None] + c += b2[None, :] + if not squared: np.sqrt(c, out=c) if to_numpy: @@ -51,6 +50,7 @@ def euclidean_distances(a, b, squared=False, to_numpy=True): else: return c + def dist(x1, x2=None, metric='sqeuclidean', to_numpy=True): """Compute distance between samples in x1 and x2 on gpu @@ -61,8 +61,8 @@ def dist(x1, x2=None, metric='sqeuclidean', to_numpy=True): matrix with n1 samples of size d x2 : np.array (n2,d), optional matrix with n2 samples of size d (if None then x2=x1) - metric : str - Metric from 'sqeuclidean', 'euclidean', + metric : str + Metric from 'sqeuclidean', 'euclidean', Returns @@ -80,7 +80,6 @@ def dist(x1, x2=None, metric='sqeuclidean', to_numpy=True): return euclidean_distances(x1, x2, squared=False, to_numpy=to_numpy) else: raise NotImplementedError - def to_gpu(*args): @@ -91,10 +90,9 @@ def to_gpu(*args): return cp.asarray(args[0]) - def to_np(*args): """ convert GPU arras to numpy and return them""" if len(args) > 1: return (cp.asnumpy(x) for x in args) else: - return cp.asnumpy(args[0])
\ No newline at end of file + return cp.asnumpy(args[0]) |