From ca69658400dc2ef6a7d3e531acffcd107400085f Mon Sep 17 00:00:00 2001
From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>
Date: Fri, 3 Dec 2021 12:37:05 +0100
Subject: [MRG] Cupy backend (#315)

* Cupy backend

* pep8 + bug

* working even if cupy not installed

* attempt to force codecov to ignore cupy because no gpu can be used for testing on github

* docstring

* readme
---
 README.md            |   1 +
 ot/backend.py        | 302 ++++++++++++++++++++++++++++++++++++++++++++++++++-
 ot/gromov.py         |  35 +++---
 ot/optim.py          |   2 +-
 test/test_backend.py |  21 +++-
 test/test_bregman.py |   1 +
 test/test_gromov.py  |  19 ++--
 7 files changed, 355 insertions(+), 26 deletions(-)

diff --git a/README.md b/README.md
index ff8056a..18064a3 100644
--- a/README.md
+++ b/README.md
@@ -196,6 +196,7 @@ The contributors to this library are
 * [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
 * [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
 * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
+* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
 
 This toolbox benefits a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
diff --git a/ot/backend.py b/ot/backend.py
index fa164c3..1630ac4 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -3,7 +3,7 @@ Multi-lib backend for POT
 
 The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch,
-or Jax, POT code should work nonetheless.
+Jax, or Cupy, POT code should work nonetheless.
 To achieve that, POT provides backend classes which implement functions in their
 respective backend imitating the Numpy API. As a convention, we use nx instead
 of np to refer to the backend.
@@ -44,6 +44,14 @@ except ImportError:
     jax = False
     jax_type = float
 
+try:
+    import cupy as cp
+    import cupyx
+    cp_type = cp.ndarray
+except ImportError:
+    cp = False
+    cp_type = float
+
 str_type_error = "All array should be from the same type/backend. Current types are : {}"
@@ -57,6 +65,9 @@ def get_backend_list():
     if jax:
         lst.append(JaxBackend())
 
+    if cp:
+        lst.append(CupyBackend())
+
     return lst
@@ -78,6 +89,8 @@ def get_backend(*args):
         return TorchBackend()
     elif isinstance(args[0], jax_type):
         return JaxBackend()
+    elif isinstance(args[0], cp_type):
+        return CupyBackend()
     else:
         raise ValueError("Unknown type of non implemented backend.")
@@ -94,7 +107,8 @@ def to_numpy(*args):
 
 class Backend():
     """
     Backend abstract class.
-    Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`
+    Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`,
+    :py:class:`CupyBackend`
 
     - The `__name__` class attribute refers to the name of the backend.
     - The `__type__` class attribute refers to the data structure used by the backend.
@@ -1500,3 +1514,287 @@ class TorchBackend(Backend):
         assert a_dtype == b_dtype, "Dtype discrepancy"
         assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
+
+class CupyBackend(Backend):  # pragma: no cover
+    """
+    CuPy implementation of the backend
+
+    - `__name__` is "cupy"
+    - `__type__` is cp.ndarray
+    """
+
+    __name__ = 'cupy'
+    __type__ = cp_type
+    __type_list__ = None
+
+    rng_ = None
+
+    def __init__(self):
+        self.rng_ = cp.random.RandomState()
+
+        self.__type_list__ = [
+            cp.array(1, dtype=cp.float32),
+            cp.array(1, dtype=cp.float64)
+        ]
+
+    def to_numpy(self, a):
+        return cp.asnumpy(a)
+
+    def from_numpy(self, a, type_as=None):
+        if type_as is None:
+            return cp.asarray(a)
+        else:
+            with cp.cuda.Device(type_as.device):
+                return cp.asarray(a, dtype=type_as.dtype)
+
+    def set_gradients(self, val, inputs, grads):
+        # No gradients for cupy
+        return val
+
+    def zeros(self, shape, type_as=None):
+        if isinstance(shape, (list, tuple)):
+            shape = tuple(int(i) for i in shape)
+        if type_as is None:
+            return cp.zeros(shape)
+        else:
+            with cp.cuda.Device(type_as.device):
+                return cp.zeros(shape, dtype=type_as.dtype)
+
+    def ones(self, shape, type_as=None):
+        if isinstance(shape, (list, tuple)):
+            shape = tuple(int(i) for i in shape)
+        if type_as is None:
+            return cp.ones(shape)
+        else:
+            with cp.cuda.Device(type_as.device):
+                return cp.ones(shape, dtype=type_as.dtype)
+
+    def arange(self, stop, start=0, step=1, type_as=None):
+        return cp.arange(start, stop, step)
+
+    def full(self, shape, fill_value, type_as=None):
+        if isinstance(shape, (list, tuple)):
+            shape = tuple(int(i) for i in shape)
+        if type_as is None:
+            return cp.full(shape, fill_value)
+        else:
+            with cp.cuda.Device(type_as.device):
+                return cp.full(shape, fill_value, dtype=type_as.dtype)
+
+    def eye(self, N, M=None, type_as=None):
+        if type_as is None:
+            return cp.eye(N, M)
+        else:
+            with cp.cuda.Device(type_as.device):
+                return cp.eye(N, M, dtype=type_as.dtype)
+
+    def sum(self, a, axis=None, keepdims=False):
+        return cp.sum(a, axis, keepdims=keepdims)
+
+    def cumsum(self, a, axis=None):
+        return cp.cumsum(a, axis)
+
+    def max(self, a, axis=None, keepdims=False):
+        return cp.max(a, axis, keepdims=keepdims)
+
+    def min(self, a, axis=None, keepdims=False):
+        return cp.min(a, axis, keepdims=keepdims)
+
+    def maximum(self, a, b):
+        return cp.maximum(a, b)
+
+    def minimum(self, a, b):
+        return cp.minimum(a, b)
+
+    def abs(self, a):
+        return cp.abs(a)
+
+    def exp(self, a):
+        return cp.exp(a)
+
+    def log(self, a):
+        return cp.log(a)
+
+    def sqrt(self, a):
+        return cp.sqrt(a)
+
+    def power(self, a, exponents):
+        return cp.power(a, exponents)
+
+    def dot(self, a, b):
+        return cp.dot(a, b)
+
+    def norm(self, a):
+        return cp.sqrt(cp.sum(cp.square(a)))
+
+    def any(self, a):
+        return cp.any(a)
+
+    def isnan(self, a):
+        return cp.isnan(a)
+
+    def isinf(self, a):
+        return cp.isinf(a)
+
+    def einsum(self, subscripts, *operands):
+        return cp.einsum(subscripts, *operands)
+
+    def sort(self, a, axis=-1):
+        return cp.sort(a, axis)
+
+    def argsort(self, a, axis=-1):
+        return cp.argsort(a, axis)
+
+    def searchsorted(self, a, v, side='left'):
+        if a.ndim == 1:
+            return cp.searchsorted(a, v, side)
+        else:
+            # this is not a very efficient way to make cupy's
+            # searchsorted work on 2d arrays
+            ret = cp.empty(v.shape, dtype=int)
+            for i in range(a.shape[0]):
+                ret[i, :] = cp.searchsorted(a[i, :], v[i, :], side)
+            return ret
+
+    def flip(self, a, axis=None):
+        return cp.flip(a, axis)
+
+    def outer(self, a, b):
+        return cp.outer(a, b)
+
+    def clip(self, a, a_min, a_max):
+        return cp.clip(a, a_min, a_max)
+
+    def repeat(self, a, repeats, axis=None):
+        return cp.repeat(a, repeats, axis)
+
+    def take_along_axis(self, arr, indices, axis):
+        return cp.take_along_axis(arr, indices, axis)
+
+    def concatenate(self, arrays, axis=0):
+        return cp.concatenate(arrays, axis)
+
+    def zero_pad(self, a, pad_width):
+        return cp.pad(a, pad_width)
+
+    def argmax(self, a, axis=None):
+        return cp.argmax(a, axis=axis)
+
+    def mean(self, a, axis=None):
+        return cp.mean(a, axis=axis)
+
+    def std(self, a, axis=None):
+        return cp.std(a, axis=axis)
+
+    def linspace(self, start, stop, num):
+        return cp.linspace(start, stop, num)
+
+    def meshgrid(self, a, b):
+        return cp.meshgrid(a, b)
+
+    def diag(self, a, k=0):
+        return cp.diag(a, k)
+
+    def unique(self, a):
+        return cp.unique(a)
+
+    def logsumexp(self, a, axis=None):
+        # Taken from
+        # https://github.com/scipy/scipy/blob/v1.7.1/scipy/special/_logsumexp.py#L7-L127
+        a_max = cp.amax(a, axis=axis, keepdims=True)
+
+        if a_max.ndim > 0:
+            a_max[~cp.isfinite(a_max)] = 0
+        elif not cp.isfinite(a_max):
+            a_max = 0
+
+        tmp = cp.exp(a - a_max)
+        s = cp.sum(tmp, axis=axis)
+        out = cp.log(s)
+        a_max = cp.squeeze(a_max, axis=axis)
+        out += a_max
+        return out
+
+    def stack(self, arrays, axis=0):
+        return cp.stack(arrays, axis)
+
+    def reshape(self, a, shape):
+        return cp.reshape(a, shape)
+
+    def seed(self, seed=None):
+        if seed is not None:
+            self.rng_.seed(seed)
+
+    def rand(self, *size, type_as=None):
+        if type_as is None:
+            return self.rng_.rand(*size)
+        else:
+            with cp.cuda.Device(type_as.device):
+                return self.rng_.rand(*size, dtype=type_as.dtype)
+
+    def randn(self, *size, type_as=None):
+        if type_as is None:
+            return self.rng_.randn(*size)
+        else:
+            with cp.cuda.Device(type_as.device):
+                return self.rng_.randn(*size, dtype=type_as.dtype)
+
+    def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+        data = self.from_numpy(data)
+        rows = self.from_numpy(rows)
+        cols = self.from_numpy(cols)
+        if type_as is None:
+            return cupyx.scipy.sparse.coo_matrix(
+                (data, (rows, cols)), shape=shape
+            )
+        else:
+            with cp.cuda.Device(type_as.device):
+                return cupyx.scipy.sparse.coo_matrix(
+                    (data, (rows, cols)), shape=shape, dtype=type_as.dtype
+                )
+
+    def issparse(self, a):
+        return cupyx.scipy.sparse.issparse(a)
+
+    def tocsr(self, a):
+        if self.issparse(a):
+            return a.tocsr()
+        else:
+            return cupyx.scipy.sparse.csr_matrix(a)
+
+    def eliminate_zeros(self, a, threshold=0.):
+        if threshold > 0:
+            if self.issparse(a):
+                a.data[self.abs(a.data) <= threshold] = 0
+            else:
+                a[self.abs(a) <= threshold] = 0
+        if self.issparse(a):
+            a.eliminate_zeros()
+        return a
+
+    def todense(self, a):
+        if self.issparse(a):
+            return a.toarray()
+        else:
+            return a
+
+    def where(self, condition, x, y):
+        return cp.where(condition, x, y)
+
+    def copy(self, a):
+        return a.copy()
+
+    def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+        return cp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+    def dtype_device(self, a):
+        return a.dtype, a.device
+
+    def assert_same_dtype_device(self, a, b):
+        a_dtype, a_device = self.dtype_device(a)
+        b_dtype, b_device = self.dtype_device(b)
+
+        # cupy has implicit type conversion so
+        # we automatically validate the test for type
+        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
diff --git a/ot/gromov.py b/ot/gromov.py
index ea667e4..2a70070 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -822,8 +822,12 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
     index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
     index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
 
-    index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
-    index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
+    index_i = generator.choice(
+        len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
+    )
+    index_j = generator.choice(
+        len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
+    )
 
     for i in range(nb_samples_p):
         if nx.issparse(T):
@@ -836,13 +840,13 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
         index_k[i] = generator.choice(
             len_q,
             size=nb_samples_q,
-            p=T_indexi / nx.sum(T_indexi),
+            p=nx.to_numpy(T_indexi / nx.sum(T_indexi)),
             replace=True
         )
         index_l[i] = generator.choice(
             len_q,
             size=nb_samples_q,
-            p=T_indexj / nx.sum(T_indexj),
+            p=nx.to_numpy(T_indexj / nx.sum(T_indexj)),
             replace=True
         )
 
@@ -934,15 +938,17 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun,
     index = np.zeros(2, dtype=int)
 
     # Initialize with default marginal
-    index[0] = generator.choice(len_p, size=1, p=p)
-    index[1] = generator.choice(len_q, size=1, p=q)
+    index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
+    index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q))
     T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False))
     best_gw_dist_estimated = np.inf
     for cpt in range(max_iter):
-        index[0] = generator.choice(len_p, size=1, p=p)
+        index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
         T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
-        index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum())
+        index[1] = generator.choice(
+            len_q, size=1, p=nx.to_numpy(T_index0 / T_index0.sum())
+        )
 
         if alpha == 1:
             T = nx.tocsr(
@@ -1071,13 +1077,16 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun,
     C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10)
 
     for cpt in range(max_iter):
-        index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False)
+        index0 = generator.choice(
+            len_p, size=nb_samples_grad_p, p=nx.to_numpy(p), replace=False
+        )
         Lik = 0
         for i, index0_i in enumerate(index0):
-            index1 = generator.choice(len_q,
-                                      size=nb_samples_grad_q,
-                                      p=T[index0_i, :] / nx.sum(T[index0_i, :]),
-                                      replace=False)
+            index1 = generator.choice(
+                len_q, size=nb_samples_grad_q,
+                p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])),
+                replace=False
+            )
             # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly.
             if (not C_are_symmetric) and generator.rand(1) > 0.5:
                 Lik += nx.mean(loss_fun(
diff --git a/ot/optim.py b/ot/optim.py
index 9b8a8f7..f25e2c9 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -88,7 +88,7 @@ def line_search_armijo(
     else:
         if alpha_min is not None or alpha_max is not None:
             alpha = np.clip(alpha, alpha_min, alpha_max)
-    return alpha, fc[0], phi1
+    return float(alpha), fc[0], phi1
 
 
 def solve_linesearch(
diff --git a/test/test_backend.py b/test/test_backend.py
index 1832b91..2e7eecc 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -7,7 +7,7 @@ import ot
 import ot.backend
-from ot.backend import torch, jax
+from ot.backend import torch, jax, cp
 
 import pytest
@@ -87,6 +87,20 @@ def test_get_backend():
         with pytest.raises(ValueError):
             get_backend(A, B2)
 
+    if cp:
+        A2 = cp.asarray(A)
+        B2 = cp.asarray(B)
+
+        nx = get_backend(A2)
+        assert nx.__name__ == 'cupy'
+
+        nx = get_backend(A2, B2)
+        assert nx.__name__ == 'cupy'
+
+        # test not unique types in input
+        with pytest.raises(ValueError):
+            get_backend(A, B2)
+
 
 def test_convert_between_backends(nx):
@@ -240,7 +254,7 @@ def test_func_backends(nx):
     # Sparse tensors test
     sp_row = np.array([0, 3, 1, 0, 3])
     sp_col = np.array([0, 3, 1, 2, 2])
-    sp_data = np.array([4, 5, 7, 9, 0])
+    sp_data = np.array([4, 5, 7, 9, 0], dtype=np.float64)
 
     lst_tot = []
@@ -393,7 +407,8 @@ def test_func_backends(nx):
         lst_b.append(nx.to_numpy(A))
         lst_name.append('argsort')
 
-        A = nx.searchsorted(Mb, Mb, 'right')
+        tmp = nx.sort(Mb)
+        A = nx.searchsorted(tmp, tmp, 'right')
         lst_b.append(nx.to_numpy(A))
         lst_name.append('searchsorted')
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 830052d..f42ac6f 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -888,6 +888,7 @@ def test_implemented_methods():
             ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
 
 
+@pytest.skip_backend("cupy")
 @pytest.skip_backend("jax")
 @pytest.mark.filterwarnings("ignore:Bottleneck")
 def test_screenkhorn(nx):
diff --git a/test/test_gromov.py b/test/test_gromov.py
index c4bc04c..5c181f2 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -54,9 +54,12 @@ def test_gromov(nx):
 
     gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
     gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
+    gwb = nx.to_numpy(gwb)
 
     gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
-    gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+    gw_valb = nx.to_numpy(
+        ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+    )
 
     G = log['T']
     Gb = nx.to_numpy(logb['T'])
@@ -188,6 +191,7 @@ def test_entropic_gromov(nx):
         C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
     gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
         C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True)
+    gwb = nx.to_numpy(gwb)
 
     G = log['T']
     Gb = nx.to_numpy(logb['T'])
@@ -287,8 +291,8 @@ def test_pointwise_gromov(nx):
     np.testing.assert_allclose(
         q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
 
-    np.testing.assert_allclose(logb['gw_dist_estimated'], 0.0, atol=1e-08)
-    np.testing.assert_allclose(logb['gw_dist_std'], 0.0, atol=1e-08)
+    np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.0, atol=1e-08)
+    np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0, atol=1e-08)
 
     G, log = ot.gromov.pointwise_gromov_wasserstein(
         C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
@@ -298,8 +302,8 @@ def test_pointwise_gromov(nx):
 
     Gb = nx.to_numpy(nx.todense(Gb))
     np.testing.assert_allclose(G, Gb, atol=1e-06)
 
-    np.testing.assert_allclose(logb['gw_dist_estimated'], 0.10342276348494964, atol=1e-8)
-    np.testing.assert_allclose(logb['gw_dist_std'], 0.0015952535464736394, atol=1e-8)
+    np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.10342276348494964, atol=1e-8)
+    np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0015952535464736394, atol=1e-8)
 
 
 @pytest.skip_backend("jax", reason="test very slow with jax backend")
@@ -346,8 +350,8 @@ def test_sampled_gromov(nx):
     np.testing.assert_allclose(
         q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
 
-    np.testing.assert_allclose(logb['gw_dist_estimated'], 0.05679474884977278, atol=1e-08)
-    np.testing.assert_allclose(logb['gw_dist_std'], 0.0005986592106971995, atol=1e-08)
+    np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.05679474884977278, atol=1e-08)
+    np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0005986592106971995, atol=1e-08)
 
 
 def test_gromov_barycenter(nx):
@@ -486,6 +490,7 @@ def test_fgw(nx):
 
     fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
     fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
+    fgwb = nx.to_numpy(fgwb)
 
     G = log['T']
     Gb = nx.to_numpy(logb['T'])
-- 
cgit v1.2.3
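
A few usage notes on the patterns this patch introduces. First, the ot/backend.py docstring describes the nx convention without showing it in use. Below is a minimal sketch of that pattern, assuming CuPy and a CUDA device are available; pairwise_sq_dists is an illustrative helper, not a function from this patch or from POT.

    # Minimal sketch of the nx convention from ot/backend.py, assuming
    # CuPy and a CUDA device are available. `pairwise_sq_dists` is an
    # illustrative helper, not part of POT or of this patch.
    import cupy as cp
    from ot.backend import get_backend

    def pairwise_sq_dists(x, y):
        nx = get_backend(x, y)  # returns CupyBackend for cp.ndarray inputs
        x2 = nx.sum(x ** 2, axis=1, keepdims=True)
        y2 = nx.sum(y ** 2, axis=1, keepdims=True)
        return x2 + y2.T - 2 * nx.dot(x, y.T)

    x = cp.random.rand(5, 3)
    y = cp.random.rand(4, 3)
    M = pairwise_sq_dists(x, y)  # result stays on the GPU as a cp.ndarray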
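Second, most CupyBackend constructors take a type_as argument and allocate inside a cp.cuda.Device context. The sketch below isolates that pattern, assuming at least one CUDA device; it mirrors the calls the patch itself makes rather than documenting any additional CuPy API.

    # The type_as pattern used throughout CupyBackend: new arrays are
    # allocated on the same GPU and with the same dtype as a reference
    # array. Sketch assuming at least one CUDA device.
    import cupy as cp

    ref = cp.ones(3, dtype=cp.float32)           # lives on some device
    with cp.cuda.Device(ref.device):             # switch to that device
        out = cp.zeros((2, 2), dtype=ref.dtype)  # same GPU, same dtype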
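Third, CupyBackend.logsumexp ports SciPy's numerically stable log-sum-exp: it subtracts the running maximum before exponentiating so that exp never overflows, then adds the maximum back after the log. A tiny check of that trick, written with NumPy so it runs without a GPU:

    # The max-shift trick behind CupyBackend.logsumexp: the naive form
    # overflows, the shifted form does not, since
    # log(2 * e**1000) = 1000 + log(2).
    import numpy as np

    a = np.array([1000.0, 1000.0])
    naive = np.log(np.sum(np.exp(a)))                   # inf (overflow)
    a_max = np.max(a)
    stable = a_max + np.log(np.sum(np.exp(a - a_max)))  # 1000.6931...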
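Finally, the ot/gromov.py hunks all follow one pattern: sampling still goes through numpy's RandomState, whose choice method needs a host-side NumPy array for the probability vector p, so backend arrays (for CuPy, living on the GPU) are converted with nx.to_numpy first. A short sketch of the same call, assuming CuPy is installed:

    # Backend probability vectors must be moved to host memory before
    # being passed as `p` to numpy's RandomState.choice, which is what
    # the nx.to_numpy wrappers added in ot/gromov.py do.
    import numpy as np
    import cupy as cp
    from ot.backend import get_backend

    p = cp.full(10, 0.1)             # marginal weights on the GPU
    nx = get_backend(p)
    generator = np.random.RandomState(42)
    idx = generator.choice(10, size=4, p=nx.to_numpy(p), replace=False)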