diff options
author | Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> | 2021-12-03 12:37:05 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-12-03 12:37:05 +0100 |
commit | ca69658400dc2ef6a7d3e531acffcd107400085f (patch) | |
tree | b77a28821067be5240cec2082fa1f119b1cfd1cd /ot | |
parent | cb510644b2fd65e4ce216a7799ce7401f71548b8 (diff) |
[MRG] Cupy backend (#315)
* Cupy backend
* pep8 + bug
* working even if cupy not installed
* attempt to force codecov to ignore cupy because no gpu can be used for testing on github
* docstring
* readme
Diffstat (limited to 'ot')
-rw-r--r-- | ot/backend.py | 302 | ||||
-rw-r--r-- | ot/gromov.py | 35 | ||||
-rw-r--r-- | ot/optim.py | 2 |
3 files changed, 323 insertions, 16 deletions
diff --git a/ot/backend.py b/ot/backend.py index fa164c3..1630ac4 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -3,7 +3,7 @@ Multi-lib backend for POT The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch, -or Jax, POT code should work nonetheless. +Jax, or Cupy, POT code should work nonetheless. To achieve that, POT provides backend classes which implements functions in their respective backend imitating Numpy API. As a convention, we use nx instead of np to refer to the backend. @@ -44,6 +44,14 @@ except ImportError: jax = False jax_type = float +try: + import cupy as cp + import cupyx + cp_type = cp.ndarray +except ImportError: + cp = False + cp_type = float + str_type_error = "All array should be from the same type/backend. Current types are : {}" @@ -57,6 +65,9 @@ def get_backend_list(): if jax: lst.append(JaxBackend()) + if cp: + lst.append(CupyBackend()) + return lst @@ -78,6 +89,8 @@ def get_backend(*args): return TorchBackend() elif isinstance(args[0], jax_type): return JaxBackend() + elif isinstance(args[0], cp_type): + return CupyBackend() else: raise ValueError("Unknown type of non implemented backend.") @@ -94,7 +107,8 @@ def to_numpy(*args): class Backend(): """ Backend abstract class. - Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend` + Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`, + :py:class:`CupyBackend` - The `__name__` class attribute refers to the name of the backend. - The `__type__` class attribute refers to the data structure used by the backend. @@ -1500,3 +1514,287 @@ class TorchBackend(Backend): assert a_dtype == b_dtype, "Dtype discrepancy" assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" + + +class CupyBackend(Backend): # pragma: no cover + """ + CuPy implementation of the backend + + - `__name__` is "cupy" + - `__type__` is cp.ndarray + """ + + __name__ = 'cupy' + __type__ = cp_type + __type_list__ = None + + rng_ = None + + def __init__(self): + self.rng_ = cp.random.RandomState() + + self.__type_list__ = [ + cp.array(1, dtype=cp.float32), + cp.array(1, dtype=cp.float64) + ] + + def to_numpy(self, a): + return cp.asnumpy(a) + + def from_numpy(self, a, type_as=None): + if type_as is None: + return cp.asarray(a) + else: + with cp.cuda.Device(type_as.device): + return cp.asarray(a, dtype=type_as.dtype) + + def set_gradients(self, val, inputs, grads): + # No gradients for cupy + return val + + def zeros(self, shape, type_as=None): + if isinstance(shape, (list, tuple)): + shape = tuple(int(i) for i in shape) + if type_as is None: + return cp.zeros(shape) + else: + with cp.cuda.Device(type_as.device): + return cp.zeros(shape, dtype=type_as.dtype) + + def ones(self, shape, type_as=None): + if isinstance(shape, (list, tuple)): + shape = tuple(int(i) for i in shape) + if type_as is None: + return cp.ones(shape) + else: + with cp.cuda.Device(type_as.device): + return cp.ones(shape, dtype=type_as.dtype) + + def arange(self, stop, start=0, step=1, type_as=None): + return cp.arange(start, stop, step) + + def full(self, shape, fill_value, type_as=None): + if isinstance(shape, (list, tuple)): + shape = tuple(int(i) for i in shape) + if type_as is None: + return cp.full(shape, fill_value) + else: + with cp.cuda.Device(type_as.device): + return cp.full(shape, fill_value, dtype=type_as.dtype) + + def eye(self, N, M=None, type_as=None): + if type_as is None: + return cp.eye(N, M) + else: + with cp.cuda.Device(type_as.device): + return cp.eye(N, M, dtype=type_as.dtype) + + def sum(self, a, axis=None, keepdims=False): + return cp.sum(a, axis, keepdims=keepdims) + + def cumsum(self, a, axis=None): + return cp.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + return cp.max(a, axis, keepdims=keepdims) + + def min(self, a, axis=None, keepdims=False): + return cp.min(a, axis, keepdims=keepdims) + + def maximum(self, a, b): + return cp.maximum(a, b) + + def minimum(self, a, b): + return cp.minimum(a, b) + + def abs(self, a): + return cp.abs(a) + + def exp(self, a): + return cp.exp(a) + + def log(self, a): + return cp.log(a) + + def sqrt(self, a): + return cp.sqrt(a) + + def power(self, a, exponents): + return cp.power(a, exponents) + + def dot(self, a, b): + return cp.dot(a, b) + + def norm(self, a): + return cp.sqrt(cp.sum(cp.square(a))) + + def any(self, a): + return cp.any(a) + + def isnan(self, a): + return cp.isnan(a) + + def isinf(self, a): + return cp.isinf(a) + + def einsum(self, subscripts, *operands): + return cp.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + return cp.sort(a, axis) + + def argsort(self, a, axis=-1): + return cp.argsort(a, axis) + + def searchsorted(self, a, v, side='left'): + if a.ndim == 1: + return cp.searchsorted(a, v, side) + else: + # this is a not very efficient way to make numpy + # searchsorted work on 2d arrays + ret = cp.empty(v.shape, dtype=int) + for i in range(a.shape[0]): + ret[i, :] = cp.searchsorted(a[i, :], v[i, :], side) + return ret + + def flip(self, a, axis=None): + return cp.flip(a, axis) + + def outer(self, a, b): + return cp.outer(a, b) + + def clip(self, a, a_min, a_max): + return cp.clip(a, a_min, a_max) + + def repeat(self, a, repeats, axis=None): + return cp.repeat(a, repeats, axis) + + def take_along_axis(self, arr, indices, axis): + return cp.take_along_axis(arr, indices, axis) + + def concatenate(self, arrays, axis=0): + return cp.concatenate(arrays, axis) + + def zero_pad(self, a, pad_width): + return cp.pad(a, pad_width) + + def argmax(self, a, axis=None): + return cp.argmax(a, axis=axis) + + def mean(self, a, axis=None): + return cp.mean(a, axis=axis) + + def std(self, a, axis=None): + return cp.std(a, axis=axis) + + def linspace(self, start, stop, num): + return cp.linspace(start, stop, num) + + def meshgrid(self, a, b): + return cp.meshgrid(a, b) + + def diag(self, a, k=0): + return cp.diag(a, k) + + def unique(self, a): + return cp.unique(a) + + def logsumexp(self, a, axis=None): + # Taken from + # https://github.com/scipy/scipy/blob/v1.7.1/scipy/special/_logsumexp.py#L7-L127 + a_max = cp.amax(a, axis=axis, keepdims=True) + + if a_max.ndim > 0: + a_max[~cp.isfinite(a_max)] = 0 + elif not cp.isfinite(a_max): + a_max = 0 + + tmp = cp.exp(a - a_max) + s = cp.sum(tmp, axis=axis) + out = cp.log(s) + a_max = cp.squeeze(a_max, axis=axis) + out += a_max + return out + + def stack(self, arrays, axis=0): + return cp.stack(arrays, axis) + + def reshape(self, a, shape): + return cp.reshape(a, shape) + + def seed(self, seed=None): + if seed is not None: + self.rng_.seed(seed) + + def rand(self, *size, type_as=None): + if type_as is None: + return self.rng_.rand(*size) + else: + with cp.cuda.Device(type_as.device): + return self.rng_.rand(*size, dtype=type_as.dtype) + + def randn(self, *size, type_as=None): + if type_as is None: + return self.rng_.randn(*size) + else: + with cp.cuda.Device(type_as.device): + return self.rng_.randn(*size, dtype=type_as.dtype) + + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + data = self.from_numpy(data) + rows = self.from_numpy(rows) + cols = self.from_numpy(cols) + if type_as is None: + return cupyx.scipy.sparse.coo_matrix( + (data, (rows, cols)), shape=shape + ) + else: + with cp.cuda.Device(type_as.device): + return cupyx.scipy.sparse.coo_matrix( + (data, (rows, cols)), shape=shape, dtype=type_as.dtype + ) + + def issparse(self, a): + return cupyx.scipy.sparse.issparse(a) + + def tocsr(self, a): + if self.issparse(a): + return a.tocsr() + else: + return cupyx.scipy.sparse.csr_matrix(a) + + def eliminate_zeros(self, a, threshold=0.): + if threshold > 0: + if self.issparse(a): + a.data[self.abs(a.data) <= threshold] = 0 + else: + a[self.abs(a) <= threshold] = 0 + if self.issparse(a): + a.eliminate_zeros() + return a + + def todense(self, a): + if self.issparse(a): + return a.toarray() + else: + return a + + def where(self, condition, x, y): + return cp.where(condition, x, y) + + def copy(self, a): + return a.copy() + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + return cp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + + def dtype_device(self, a): + return a.dtype, a.device + + def assert_same_dtype_device(self, a, b): + a_dtype, a_device = self.dtype_device(a) + b_dtype, b_device = self.dtype_device(b) + + # cupy has implicit type conversion so + # we automatically validate the test for type + assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" diff --git a/ot/gromov.py b/ot/gromov.py index ea667e4..2a70070 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -822,8 +822,12 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
- index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
- index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
+ index_i = generator.choice(
+ len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
+ )
+ index_j = generator.choice(
+ len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
+ )
for i in range(nb_samples_p):
if nx.issparse(T):
@@ -836,13 +840,13 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, index_k[i] = generator.choice(
len_q,
size=nb_samples_q,
- p=T_indexi / nx.sum(T_indexi),
+ p=nx.to_numpy(T_indexi / nx.sum(T_indexi)),
replace=True
)
index_l[i] = generator.choice(
len_q,
size=nb_samples_q,
- p=T_indexj / nx.sum(T_indexj),
+ p=nx.to_numpy(T_indexj / nx.sum(T_indexj)),
replace=True
)
@@ -934,15 +938,17 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, index = np.zeros(2, dtype=int)
# Initialize with default marginal
- index[0] = generator.choice(len_p, size=1, p=p)
- index[1] = generator.choice(len_q, size=1, p=q)
+ index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
+ index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q))
T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False))
best_gw_dist_estimated = np.inf
for cpt in range(max_iter):
- index[0] = generator.choice(len_p, size=1, p=p)
+ index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
- index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum())
+ index[1] = generator.choice(
+ len_q, size=1, p=nx.to_numpy(T_index0 / T_index0.sum())
+ )
if alpha == 1:
T = nx.tocsr(
@@ -1071,13 +1077,16 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10)
for cpt in range(max_iter):
- index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False)
+ index0 = generator.choice(
+ len_p, size=nb_samples_grad_p, p=nx.to_numpy(p), replace=False
+ )
Lik = 0
for i, index0_i in enumerate(index0):
- index1 = generator.choice(len_q,
- size=nb_samples_grad_q,
- p=T[index0_i, :] / nx.sum(T[index0_i, :]),
- replace=False)
+ index1 = generator.choice(
+ len_q, size=nb_samples_grad_q,
+ p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])),
+ replace=False
+ )
# If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly.
if (not C_are_symmetric) and generator.rand(1) > 0.5:
Lik += nx.mean(loss_fun(
diff --git a/ot/optim.py b/ot/optim.py index 9b8a8f7..f25e2c9 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -88,7 +88,7 @@ def line_search_armijo( else: if alpha_min is not None or alpha_max is not None: alpha = np.clip(alpha, alpha_min, alpha_max) - return alpha, fc[0], phi1 + return float(alpha), fc[0], phi1 def solve_linesearch( |