Diffstat (limited to 'ot/backend.py')
-rw-r--r-- | ot/backend.py | 1502 |
1 file changed, 1502 insertions, 0 deletions
diff --git a/ot/backend.py b/ot/backend.py
new file mode 100644
index 0000000..a044f84
--- /dev/null
+++ b/ot/backend.py
@@ -0,0 +1,1502 @@
# -*- coding: utf-8 -*-
"""
Multi-lib backend for POT

The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch,
or Jax, POT code should work nonetheless.
To achieve that, POT provides backend classes which implement functions in their
respective backends, imitating the Numpy API. As a convention, we use nx instead
of np to refer to the backend.

Examples
--------

>>> from ot.utils import list_to_array
>>> from ot.backend import get_backend
>>> def f(a, b):  # the function does not know which backend to use
...     a, b = list_to_array(a, b)  # if a list is given, make it an array
...     nx = get_backend(a, b)  # infer the backend from the arguments
...     c = nx.dot(a, b)  # now use the backend to do any calculation
...     return c
"""

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#         Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License

import numpy as np
import scipy.special as scipy
from scipy.sparse import issparse, coo_matrix, csr_matrix

try:
    import torch
    torch_type = torch.Tensor
except ImportError:
    torch = False
    torch_type = float

try:
    import jax
    import jax.numpy as jnp
    import jax.scipy.special as jscipy
    jax_type = jax.numpy.ndarray
except ImportError:
    jax = False
    jax_type = float

str_type_error = "All arrays should be from the same type/backend. Current types are : {}"


def get_backend_list():
    """Returns the list of available backends"""
    lst = [NumpyBackend(), ]

    if torch:
        lst.append(TorchBackend())

    if jax:
        lst.append(JaxBackend())

    return lst


def get_backend(*args):
    """Returns the proper backend for a list of input arrays

    Also raises ValueError if all arrays are not from the same backend
    """
    # check that some arrays are given
    if not len(args) > 0:
        raise ValueError("The function takes at least one parameter")
    # check that all arrays have the same type
    if not len(set(type(a) for a in args)) == 1:
        raise ValueError(str_type_error.format([type(a) for a in args]))

    if isinstance(args[0], np.ndarray):
        return NumpyBackend()
    elif isinstance(args[0], torch_type):
        return TorchBackend()
    elif isinstance(args[0], jax_type):
        return JaxBackend()
    else:
        raise ValueError("Unknown type of non implemented backend.")


def to_numpy(*args):
    """Returns numpy arrays from any compatible backend"""

    if len(args) == 1:
        return get_backend(args[0]).to_numpy(args[0])
    else:
        return [get_backend(a).to_numpy(a) for a in args]
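
# --- Illustrative sketch (editor's addition, not part of the committed file) ---
# A minimal, hedged usage example of the dispatch above, assuming NumPy is
# installed:
#
#     import numpy as np
#     a, b = np.ones(3), np.arange(3.0)
#     nx = get_backend(a, b)            # -> NumpyBackend instance
#     float(to_numpy(nx.dot(a, b)))     # 3.0; to_numpy is a no-op for NumPy
#     get_backend(a, [0.0, 1.0])        # mixed types -> raises ValueError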
+ """ + + __name__ = None + __type__ = None + __type_list__ = None + + rng_ = None + + def __str__(self): + return self.__name__ + + # convert to numpy + def to_numpy(self, a): + """Returns the numpy version of a tensor""" + raise NotImplementedError() + + # convert from numpy + def from_numpy(self, a, type_as=None): + """Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)""" + raise NotImplementedError() + + def set_gradients(self, val, inputs, grads): + """Define the gradients for the value val wrt the inputs """ + raise NotImplementedError() + + def zeros(self, shape, type_as=None): + r""" + Creates a tensor full of zeros. + + This function follows the api from :any:`numpy.zeros` + + See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html + """ + raise NotImplementedError() + + def ones(self, shape, type_as=None): + r""" + Creates a tensor full of ones. + + This function follows the api from :any:`numpy.ones` + + See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html + """ + raise NotImplementedError() + + def arange(self, stop, start=0, step=1, type_as=None): + r""" + Returns evenly spaced values within a given interval. + + This function follows the api from :any:`numpy.arange` + + See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html + """ + raise NotImplementedError() + + def full(self, shape, fill_value, type_as=None): + r""" + Creates a tensor with given shape, filled with given value. + + This function follows the api from :any:`numpy.full` + + See: https://numpy.org/doc/stable/reference/generated/numpy.full.html + """ + raise NotImplementedError() + + def eye(self, N, M=None, type_as=None): + r""" + Creates the identity matrix of given size. + + This function follows the api from :any:`numpy.eye` + + See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html + """ + raise NotImplementedError() + + def sum(self, a, axis=None, keepdims=False): + r""" + Sums tensor elements over given dimensions. + + This function follows the api from :any:`numpy.sum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html + """ + raise NotImplementedError() + + def cumsum(self, a, axis=None): + r""" + Returns the cumulative sum of tensor elements over given dimensions. + + This function follows the api from :any:`numpy.cumsum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html + """ + raise NotImplementedError() + + def max(self, a, axis=None, keepdims=False): + r""" + Returns the maximum of an array or maximum along given dimensions. + + This function follows the api from :any:`numpy.amax` + + See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html + """ + raise NotImplementedError() + + def min(self, a, axis=None, keepdims=False): + r""" + Returns the maximum of an array or maximum along given dimensions. + + This function follows the api from :any:`numpy.amin` + + See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html + """ + raise NotImplementedError() + + def maximum(self, a, b): + r""" + Returns element-wise maximum of array elements. + + This function follows the api from :any:`numpy.maximum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html + """ + raise NotImplementedError() + + def minimum(self, a, b): + r""" + Returns element-wise minimum of array elements. 
+ + This function follows the api from :any:`numpy.minimum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html + """ + raise NotImplementedError() + + def dot(self, a, b): + r""" + Returns the dot product of two tensors. + + This function follows the api from :any:`numpy.dot` + + See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html + """ + raise NotImplementedError() + + def abs(self, a): + r""" + Computes the absolute value element-wise. + + This function follows the api from :any:`numpy.absolute` + + See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html + """ + raise NotImplementedError() + + def exp(self, a): + r""" + Computes the exponential value element-wise. + + This function follows the api from :any:`numpy.exp` + + See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html + """ + raise NotImplementedError() + + def log(self, a): + r""" + Computes the natural logarithm, element-wise. + + This function follows the api from :any:`numpy.log` + + See: https://numpy.org/doc/stable/reference/generated/numpy.log.html + """ + raise NotImplementedError() + + def sqrt(self, a): + r""" + Returns the non-ngeative square root of a tensor, element-wise. + + This function follows the api from :any:`numpy.sqrt` + + See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html + """ + raise NotImplementedError() + + def power(self, a, exponents): + r""" + First tensor elements raised to powers from second tensor, element-wise. + + This function follows the api from :any:`numpy.power` + + See: https://numpy.org/doc/stable/reference/generated/numpy.power.html + """ + raise NotImplementedError() + + def norm(self, a): + r""" + Computes the matrix frobenius norm. + + This function follows the api from :any:`numpy.linalg.norm` + + See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html + """ + raise NotImplementedError() + + def any(self, a): + r""" + Tests whether any tensor element along given dimensions evaluates to True. + + This function follows the api from :any:`numpy.any` + + See: https://numpy.org/doc/stable/reference/generated/numpy.any.html + """ + raise NotImplementedError() + + def isnan(self, a): + r""" + Tests element-wise for NaN and returns result as a boolean tensor. + + This function follows the api from :any:`numpy.isnan` + + See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html + """ + raise NotImplementedError() + + def isinf(self, a): + r""" + Tests element-wise for positive or negative infinity and returns result as a boolean tensor. + + This function follows the api from :any:`numpy.isinf` + + See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html + """ + raise NotImplementedError() + + def einsum(self, subscripts, *operands): + r""" + Evaluates the Einstein summation convention on the operands. + + This function follows the api from :any:`numpy.einsum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html + """ + raise NotImplementedError() + + def sort(self, a, axis=-1): + r""" + Returns a sorted copy of a tensor. + + This function follows the api from :any:`numpy.sort` + + See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html + """ + raise NotImplementedError() + + def argsort(self, a, axis=None): + r""" + Returns the indices that would sort a tensor. 
+ + This function follows the api from :any:`numpy.argsort` + + See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html + """ + raise NotImplementedError() + + def searchsorted(self, a, v, side='left'): + r""" + Finds indices where elements should be inserted to maintain order in given tensor. + + This function follows the api from :any:`numpy.searchsorted` + + See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html + """ + raise NotImplementedError() + + def flip(self, a, axis=None): + r""" + Reverses the order of elements in a tensor along given dimensions. + + This function follows the api from :any:`numpy.flip` + + See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html + """ + raise NotImplementedError() + + def clip(self, a, a_min, a_max): + """ + Limits the values in a tensor. + + This function follows the api from :any:`numpy.clip` + + See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html + """ + raise NotImplementedError() + + def repeat(self, a, repeats, axis=None): + r""" + Repeats elements of a tensor. + + This function follows the api from :any:`numpy.repeat` + + See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html + """ + raise NotImplementedError() + + def take_along_axis(self, arr, indices, axis): + r""" + Gathers elements of a tensor along given dimensions. + + This function follows the api from :any:`numpy.take_along_axis` + + See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html + """ + raise NotImplementedError() + + def concatenate(self, arrays, axis=0): + r""" + Joins a sequence of tensors along an existing dimension. + + This function follows the api from :any:`numpy.concatenate` + + See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html + """ + raise NotImplementedError() + + def zero_pad(self, a, pad_width): + r""" + Pads a tensor. + + This function follows the api from :any:`numpy.pad` + + See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + """ + raise NotImplementedError() + + def argmax(self, a, axis=None): + r""" + Returns the indices of the maximum values of a tensor along given dimensions. + + This function follows the api from :any:`numpy.argmax` + + See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html + """ + raise NotImplementedError() + + def mean(self, a, axis=None): + r""" + Computes the arithmetic mean of a tensor along given dimensions. + + This function follows the api from :any:`numpy.mean` + + See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html + """ + raise NotImplementedError() + + def std(self, a, axis=None): + r""" + Computes the standard deviation of a tensor along given dimensions. + + This function follows the api from :any:`numpy.std` + + See: https://numpy.org/doc/stable/reference/generated/numpy.std.html + """ + raise NotImplementedError() + + def linspace(self, start, stop, num): + r""" + Returns a specified number of evenly spaced values over a given interval. + + This function follows the api from :any:`numpy.linspace` + + See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html + """ + raise NotImplementedError() + + def meshgrid(self, a, b): + r""" + Returns coordinate matrices from coordinate vectors (Numpy convention). 
+ + This function follows the api from :any:`numpy.meshgrid` + + See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html + """ + raise NotImplementedError() + + def diag(self, a, k=0): + r""" + Extracts or constructs a diagonal tensor. + + This function follows the api from :any:`numpy.diag` + + See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html + """ + raise NotImplementedError() + + def unique(self, a): + r""" + Finds unique elements of given tensor. + + This function follows the api from :any:`numpy.unique` + + See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html + """ + raise NotImplementedError() + + def logsumexp(self, a, axis=None): + r""" + Computes the log of the sum of exponentials of input elements. + + This function follows the api from :any:`scipy.special.logsumexp` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html + """ + raise NotImplementedError() + + def stack(self, arrays, axis=0): + r""" + Joins a sequence of tensors along a new dimension. + + This function follows the api from :any:`numpy.stack` + + See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html + """ + raise NotImplementedError() + + def outer(self, a, b): + r""" + Computes the outer product between two vectors. + + This function follows the api from :any:`numpy.outer` + + See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html + """ + raise NotImplementedError() + + def reshape(self, a, shape): + r""" + Gives a new shape to a tensor without changing its data. + + This function follows the api from :any:`numpy.reshape` + + See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html + """ + raise NotImplementedError() + + def seed(self, seed=None): + r""" + Sets the seed for the random generator. + + This function follows the api from :any:`numpy.random.seed` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.seed.html + """ + raise NotImplementedError() + + def rand(self, *size, type_as=None): + r""" + Generate uniform random numbers. + + This function follows the api from :any:`numpy.random.rand` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html + """ + raise NotImplementedError() + + def randn(self, *size, type_as=None): + r""" + Generate normal Gaussian random numbers. + + This function follows the api from :any:`numpy.random.rand` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html + """ + raise NotImplementedError() + + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + r""" + Creates a sparse tensor in COOrdinate format. + + This function follows the api from :any:`scipy.sparse.coo_matrix` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html + """ + raise NotImplementedError() + + def issparse(self, a): + r""" + Checks whether or not the input tensor is a sparse tensor. + + This function follows the api from :any:`scipy.sparse.issparse` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html + """ + raise NotImplementedError() + + def tocsr(self, a): + r""" + Converts this matrix to Compressed Sparse Row format. 
+ + This function follows the api from :any:`scipy.sparse.coo_matrix.tocsr` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html + """ + raise NotImplementedError() + + def eliminate_zeros(self, a, threshold=0.): + r""" + Removes entries smaller than the given threshold from the sparse tensor. + + This function follows the api from :any:`scipy.sparse.csr_matrix.eliminate_zeros` + + See: https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html + """ + raise NotImplementedError() + + def todense(self, a): + r""" + Converts a sparse tensor to a dense tensor. + + This function follows the api from :any:`scipy.sparse.csr_matrix.toarray` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html + """ + raise NotImplementedError() + + def where(self, condition, x, y): + r""" + Returns elements chosen from x or y depending on condition. + + This function follows the api from :any:`numpy.where` + + See: https://numpy.org/doc/stable/reference/generated/numpy.where.html + """ + raise NotImplementedError() + + def copy(self, a): + r""" + Returns a copy of the given tensor. + + This function follows the api from :any:`numpy.copy` + + See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html + """ + raise NotImplementedError() + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + r""" + Returns True if two arrays are element-wise equal within a tolerance. + + This function follows the api from :any:`numpy.allclose` + + See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html + """ + raise NotImplementedError() + + def dtype_device(self, a): + r""" + Returns the dtype and the device of the given tensor. 
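
# --- Illustrative sketch (editor's addition, not part of the committed file) ---
# A hypothetical, minimal subclass showing the contract the concrete backends
# below fulfil: advertise a name and an array type, then override the abstract
# methods with calls into the underlying library.
#
#     class MiniNumpyBackend(Backend):   # hypothetical name, for illustration
#         __name__ = 'mini-numpy'
#         __type__ = np.ndarray
#
#         def to_numpy(self, a):
#             return a
#
#         def from_numpy(self, a, type_as=None):
#             return a if type_as is None else a.astype(type_as.dtype)
#
#         def sum(self, a, axis=None, keepdims=False):
#             return np.sum(a, axis, keepdims=keepdims)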
+ """ + raise NotImplementedError() + + def assert_same_dtype_device(self, a, b): + r""" + Checks whether or not the two given inputs have the same dtype as well as the same device + """ + raise NotImplementedError() + + +class NumpyBackend(Backend): + """ + NumPy implementation of the backend + + - `__name__` is "numpy" + - `__type__` is np.ndarray + """ + + __name__ = 'numpy' + __type__ = np.ndarray + __type_list__ = [np.array(1, dtype=np.float32), + np.array(1, dtype=np.float64)] + + rng_ = np.random.RandomState() + + def to_numpy(self, a): + return a + + def from_numpy(self, a, type_as=None): + if type_as is None: + return a + elif isinstance(a, float): + return a + else: + return a.astype(type_as.dtype) + + def set_gradients(self, val, inputs, grads): + # No gradients for numpy + return val + + def zeros(self, shape, type_as=None): + if type_as is None: + return np.zeros(shape) + else: + return np.zeros(shape, dtype=type_as.dtype) + + def ones(self, shape, type_as=None): + if type_as is None: + return np.ones(shape) + else: + return np.ones(shape, dtype=type_as.dtype) + + def arange(self, stop, start=0, step=1, type_as=None): + return np.arange(start, stop, step) + + def full(self, shape, fill_value, type_as=None): + if type_as is None: + return np.full(shape, fill_value) + else: + return np.full(shape, fill_value, dtype=type_as.dtype) + + def eye(self, N, M=None, type_as=None): + if type_as is None: + return np.eye(N, M) + else: + return np.eye(N, M, dtype=type_as.dtype) + + def sum(self, a, axis=None, keepdims=False): + return np.sum(a, axis, keepdims=keepdims) + + def cumsum(self, a, axis=None): + return np.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + return np.max(a, axis, keepdims=keepdims) + + def min(self, a, axis=None, keepdims=False): + return np.min(a, axis, keepdims=keepdims) + + def maximum(self, a, b): + return np.maximum(a, b) + + def minimum(self, a, b): + return np.minimum(a, b) + + def dot(self, a, b): + return np.dot(a, b) + + def abs(self, a): + return np.abs(a) + + def exp(self, a): + return np.exp(a) + + def log(self, a): + return np.log(a) + + def sqrt(self, a): + return np.sqrt(a) + + def power(self, a, exponents): + return np.power(a, exponents) + + def norm(self, a): + return np.sqrt(np.sum(np.square(a))) + + def any(self, a): + return np.any(a) + + def isnan(self, a): + return np.isnan(a) + + def isinf(self, a): + return np.isinf(a) + + def einsum(self, subscripts, *operands): + return np.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + return np.sort(a, axis) + + def argsort(self, a, axis=-1): + return np.argsort(a, axis) + + def searchsorted(self, a, v, side='left'): + if a.ndim == 1: + return np.searchsorted(a, v, side) + else: + # this is a not very efficient way to make numpy + # searchsorted work on 2d arrays + ret = np.empty(v.shape, dtype=int) + for i in range(a.shape[0]): + ret[i, :] = np.searchsorted(a[i, :], v[i, :], side) + return ret + + def flip(self, a, axis=None): + return np.flip(a, axis) + + def outer(self, a, b): + return np.outer(a, b) + + def clip(self, a, a_min, a_max): + return np.clip(a, a_min, a_max) + + def repeat(self, a, repeats, axis=None): + return np.repeat(a, repeats, axis) + + def take_along_axis(self, arr, indices, axis): + return np.take_along_axis(arr, indices, axis) + + def concatenate(self, arrays, axis=0): + return np.concatenate(arrays, axis) + + def zero_pad(self, a, pad_width): + return np.pad(a, pad_width) + + def argmax(self, a, axis=None): + return np.argmax(a, axis=axis) + + 

class JaxBackend(Backend):
    """
    JAX implementation of the backend

    - `__name__` is "jax"
    - `__type__` is jax.numpy.ndarray
    """

    __name__ = 'jax'
    __type__ = jax_type
    __type_list__ = None

    rng_ = None

    def __init__(self):
        self.rng_ = jax.random.PRNGKey(42)

        for d in jax.devices():
            self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d),
                                  jax.device_put(jnp.array(1, dtype=jnp.float64), d)]

    def to_numpy(self, a):
        return np.array(a)

    def _change_device(self, a, type_as):
        return jax.device_put(a, type_as.device_buffer.device())

    def from_numpy(self, a, type_as=None):
        if type_as is None:
            return jnp.array(a)
        else:
            return self._change_device(jnp.array(a).astype(type_as.dtype), type_as)

    def set_gradients(self, val, inputs, grads):
        from jax.flatten_util import ravel_pytree
        val, = jax.lax.stop_gradient((val,))

        ravelled_inputs, _ = ravel_pytree(inputs)
        ravelled_grads, _ = ravel_pytree(grads)

        aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2
        aux = aux - jax.lax.stop_gradient(aux)

        val, = jax.tree_map(lambda z: z + aux, (val,))
        return val

    def zeros(self, shape, type_as=None):
        if type_as is None:
            return jnp.zeros(shape)
        else:
            return self._change_device(jnp.zeros(shape, dtype=type_as.dtype), type_as)

    def ones(self, shape, type_as=None):
        if type_as is None:
            return jnp.ones(shape)
        else:
            return self._change_device(jnp.ones(shape, dtype=type_as.dtype), type_as)

    def arange(self, stop, start=0, step=1, type_as=None):
        return jnp.arange(start, stop, step)

    def full(self, shape, fill_value, type_as=None):
        if type_as is None:
            return jnp.full(shape, fill_value)
        else:
            return self._change_device(jnp.full(shape, fill_value, dtype=type_as.dtype), type_as)

    def eye(self, N, M=None, type_as=None):
        if type_as is None:
            return jnp.eye(N, M)
        else:
            return self._change_device(jnp.eye(N, M, dtype=type_as.dtype), type_as)

    def sum(self, a, axis=None, keepdims=False):
        return jnp.sum(a, axis, keepdims=keepdims)

    def cumsum(self, a, axis=None):
        return jnp.cumsum(a, axis)

    def max(self, a, axis=None, keepdims=False):
        return jnp.max(a, axis, keepdims=keepdims)

    def min(self, a, axis=None, keepdims=False):
        return jnp.min(a, axis, keepdims=keepdims)

    def maximum(self, a, b):
        return jnp.maximum(a, b)

    def minimum(self, a, b):
        return jnp.minimum(a, b)

    def dot(self, a, b):
        return jnp.dot(a, b)

    def abs(self, a):
        return jnp.abs(a)

    def exp(self, a):
        return jnp.exp(a)

    def log(self, a):
        return jnp.log(a)

    def sqrt(self, a):
        return jnp.sqrt(a)

    def power(self, a, exponents):
        return jnp.power(a, exponents)

    def norm(self, a):
        return jnp.sqrt(jnp.sum(jnp.square(a)))

    def any(self, a):
        return jnp.any(a)

    def isnan(self, a):
        return jnp.isnan(a)

    def isinf(self, a):
        return jnp.isinf(a)

    def einsum(self, subscripts, *operands):
        return jnp.einsum(subscripts, *operands)

    def sort(self, a, axis=-1):
        return jnp.sort(a, axis)

    def argsort(self, a, axis=-1):
        return jnp.argsort(a, axis)

    def searchsorted(self, a, v, side='left'):
        if a.ndim == 1:
            return jnp.searchsorted(a, v, side)
        else:
            # this is not a very efficient way to make jax numpy
            # searchsorted work on 2d arrays
            return jnp.array([jnp.searchsorted(a[i, :], v[i, :], side) for i in range(a.shape[0])])

    def flip(self, a, axis=None):
        return jnp.flip(a, axis)

    def outer(self, a, b):
        return jnp.outer(a, b)

    def clip(self, a, a_min, a_max):
        return jnp.clip(a, a_min, a_max)

    def repeat(self, a, repeats, axis=None):
        return jnp.repeat(a, repeats, axis)

    def take_along_axis(self, arr, indices, axis):
        return jnp.take_along_axis(arr, indices, axis)

    def concatenate(self, arrays, axis=0):
        return jnp.concatenate(arrays, axis)

    def zero_pad(self, a, pad_width):
        return jnp.pad(a, pad_width)

    def argmax(self, a, axis=None):
        return jnp.argmax(a, axis=axis)

    def mean(self, a, axis=None):
        return jnp.mean(a, axis=axis)

    def std(self, a, axis=None):
        return jnp.std(a, axis=axis)

    def linspace(self, start, stop, num):
        return jnp.linspace(start, stop, num)

    def meshgrid(self, a, b):
        return jnp.meshgrid(a, b)

    def diag(self, a, k=0):
        return jnp.diag(a, k)

    def unique(self, a):
        return jnp.unique(a)

    def logsumexp(self, a, axis=None):
        return jscipy.logsumexp(a, axis=axis)

    def stack(self, arrays, axis=0):
        return jnp.stack(arrays, axis)

    def reshape(self, a, shape):
        return jnp.reshape(a, shape)

    def seed(self, seed=None):
        if seed is not None:
            self.rng_ = jax.random.PRNGKey(seed)

    def rand(self, *size, type_as=None):
        self.rng_, subkey = jax.random.split(self.rng_)
        if type_as is not None:
            return jax.random.uniform(subkey, shape=size, dtype=type_as.dtype)
        else:
            return jax.random.uniform(subkey, shape=size)

    def randn(self, *size, type_as=None):
        self.rng_, subkey = jax.random.split(self.rng_)
        if type_as is not None:
            return jax.random.normal(subkey, shape=size, dtype=type_as.dtype)
        else:
            return jax.random.normal(subkey, shape=size)

    def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
        # Currently, JAX does not support sparse matrices
        data = self.to_numpy(data)
        rows = self.to_numpy(rows)
        cols = self.to_numpy(cols)
        nx = NumpyBackend()
        coo_matrix = nx.coo_matrix(data, rows, cols, shape=shape, type_as=type_as)
        matrix = nx.todense(coo_matrix)
        return self.from_numpy(matrix)

    def issparse(self, a):
        # Currently, JAX does not support sparse matrices
        return False

    def tocsr(self, a):
        # Currently, JAX does not support sparse matrices
        return a

    def eliminate_zeros(self, a, threshold=0.):
        # Currently, JAX does not support sparse matrices
        if threshold > 0:
            return self.where(
                self.abs(a) <= threshold,
                self.zeros((1,), type_as=a),
                a
            )
        return a

    def todense(self, a):
        # Currently, JAX does not support sparse matrices
        return a

    def where(self, condition, x, y):
        return jnp.where(condition, x, y)

    def copy(self, a):
        # No need to copy, JAX arrays are immutable
        return a

    def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
        return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)

    def dtype_device(self, a):
        return a.dtype, a.device_buffer.device()

    def assert_same_dtype_device(self, a, b):
        a_dtype, a_device = self.dtype_device(a)
        b_dtype, b_device = self.dtype_device(b)

        assert a_dtype == b_dtype, "Dtype discrepancy"
        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
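
# --- Illustrative sketch (editor's addition, not part of the committed file) ---
# JAX random generation is stateless, so the backend above emulates NumPy-style
# statefulness by splitting and storing a PRNG key on every draw; a hedged
# usage sketch:
#
#     nx = JaxBackend()
#     nx.seed(0)          # resets the stored PRNGKey
#     u = nx.rand(2, 3)   # consumes a subkey and advances nx.rng_
#     v = nx.rand(2, 3)   # a different subkey, hence different numbers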

class TorchBackend(Backend):
    """
    PyTorch implementation of the backend

    - `__name__` is "torch"
    - `__type__` is torch.Tensor
    """

    __name__ = 'torch'
    __type__ = torch_type
    __type_list__ = None

    rng_ = None

    def __init__(self):

        self.rng_ = torch.Generator()
        self.rng_.seed()

        self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
                              torch.tensor(1, dtype=torch.float64)]

        if torch.cuda.is_available():
            self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
            self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))

        from torch.autograd import Function

        # define a function that takes inputs, val and grads
        # and returns a val tensor with proper gradients
        class ValFunction(Function):

            @staticmethod
            def forward(ctx, val, grads, *inputs):
                ctx.grads = grads
                return val

            @staticmethod
            def backward(ctx, grad_output):
                # the gradients are the stored grads
                return (None, None) + ctx.grads

        self.ValFunction = ValFunction

    def to_numpy(self, a):
        return a.cpu().detach().numpy()

    def from_numpy(self, a, type_as=None):
        if isinstance(a, float):
            a = np.array(a)
        if type_as is None:
            return torch.from_numpy(a)
        else:
            return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device)

    def set_gradients(self, val, inputs, grads):

        Func = self.ValFunction()

        res = Func.apply(val, grads, *inputs)

        return res

    def zeros(self, shape, type_as=None):
        if isinstance(shape, int):
            shape = (shape,)
        if type_as is None:
            return torch.zeros(shape)
        else:
            return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device)

    def ones(self, shape, type_as=None):
        if isinstance(shape, int):
            shape = (shape,)
        if type_as is None:
            return torch.ones(shape)
        else:
            return torch.ones(shape, dtype=type_as.dtype, device=type_as.device)

    def arange(self, stop, start=0, step=1, type_as=None):
        if type_as is None:
            return torch.arange(start, stop, step)
        else:
            return torch.arange(start, stop, step, device=type_as.device)

    def full(self, shape, fill_value, type_as=None):
        if isinstance(shape, int):
            shape = (shape,)
        if type_as is None:
            return torch.full(shape, fill_value)
        else:
            return torch.full(shape, fill_value, dtype=type_as.dtype, device=type_as.device)

    def eye(self, N, M=None, type_as=None):
        if M is None:
            M = N
        if type_as is None:
            return torch.eye(N, m=M)
        else:
            return torch.eye(N, m=M, dtype=type_as.dtype, device=type_as.device)

    def sum(self, a, axis=None, keepdims=False):
        if axis is None:
            return torch.sum(a)
        else:
            return torch.sum(a, axis, keepdim=keepdims)

    def cumsum(self, a, axis=None):
        if axis is None:
            return torch.cumsum(a.flatten(), 0)
        else:
            return torch.cumsum(a, axis)

    def max(self, a, axis=None, keepdims=False):
        if axis is None:
            return torch.max(a)
        else:
            return torch.max(a, axis, keepdim=keepdims)[0]

    def min(self, a, axis=None, keepdims=False):
        if axis is None:
            return torch.min(a)
        else:
            return torch.min(a, axis, keepdim=keepdims)[0]

    def maximum(self, a, b):
        if isinstance(a, int) or isinstance(a, float):
            a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
        if isinstance(b, int) or isinstance(b, float):
            b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
        if hasattr(torch, "maximum"):
            return torch.maximum(a, b)
        else:
            return torch.max(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]

    def minimum(self, a, b):
        if isinstance(a, int) or isinstance(a, float):
            a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
        if isinstance(b, int) or isinstance(b, float):
            b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
        if hasattr(torch, "minimum"):
            return torch.minimum(a, b)
        else:
            return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]

    def dot(self, a, b):
        return torch.matmul(a, b)

    def abs(self, a):
        return torch.abs(a)

    def exp(self, a):
        return torch.exp(a)

    def log(self, a):
        return torch.log(a)

    def sqrt(self, a):
        return torch.sqrt(a)

    def power(self, a, exponents):
        return torch.pow(a, exponents)

    def norm(self, a):
        return torch.sqrt(torch.sum(torch.square(a)))

    def any(self, a):
        return torch.any(a)

    def isnan(self, a):
        return torch.isnan(a)

    def isinf(self, a):
        return torch.isinf(a)

    def einsum(self, subscripts, *operands):
        return torch.einsum(subscripts, *operands)

    def sort(self, a, axis=-1):
        sorted0, indices = torch.sort(a, dim=axis)
        return sorted0

    def argsort(self, a, axis=-1):
        sorted0, indices = torch.sort(a, dim=axis)
        return indices

    def searchsorted(self, a, v, side='left'):
        right = (side != 'left')
        return torch.searchsorted(a, v, right=right)

    def flip(self, a, axis=None):
        if axis is None:
            return torch.flip(a, tuple(i for i in range(len(a.shape))))
        if isinstance(axis, int):
            return torch.flip(a, (axis,))
        else:
            return torch.flip(a, dims=axis)

    def outer(self, a, b):
        return torch.outer(a, b)

    def clip(self, a, a_min, a_max):
        return torch.clamp(a, a_min, a_max)

    def repeat(self, a, repeats, axis=None):
        return torch.repeat_interleave(a, repeats, dim=axis)

    def take_along_axis(self, arr, indices, axis):
        return torch.gather(arr, axis, indices)

    def concatenate(self, arrays, axis=0):
        return torch.cat(arrays, dim=axis)

    def zero_pad(self, a, pad_width):
        from torch.nn.functional import pad
        # pad_width is an array of ndim tuples indicating how many zeros to add
        # before and after each dimension. We first need to make it compliant
        # with the torch syntax, which starts with the last dim, then the
        # second to last, etc.
        how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl)
        return pad(a, how_pad)

    def argmax(self, a, axis=None):
        return torch.argmax(a, dim=axis)

    def mean(self, a, axis=None):
        if axis is not None:
            return torch.mean(a, dim=axis)
        else:
            return torch.mean(a)

    def std(self, a, axis=None):
        if axis is not None:
            return torch.std(a, dim=axis, unbiased=False)
        else:
            return torch.std(a, unbiased=False)

    def linspace(self, start, stop, num):
        return torch.linspace(start, stop, num, dtype=torch.float64)

    def meshgrid(self, a, b):
        X, Y = torch.meshgrid(a, b)
        return X.T, Y.T

    def diag(self, a, k=0):
        return torch.diag(a, diagonal=k)

    def unique(self, a):
        return torch.unique(a)

    def logsumexp(self, a, axis=None):
        if axis is not None:
            return torch.logsumexp(a, dim=axis)
        else:
            return torch.logsumexp(a, dim=tuple(range(len(a.shape))))

    def stack(self, arrays, axis=0):
        return torch.stack(arrays, dim=axis)

    def reshape(self, a, shape):
        return torch.reshape(a, shape)

    def seed(self, seed=None):
        if isinstance(seed, int):
            self.rng_.manual_seed(seed)
        elif isinstance(seed, torch.Generator):
            self.rng_ = seed
        else:
            raise ValueError("Non compatible seed : {}".format(seed))

    def rand(self, *size, type_as=None):
        if type_as is not None:
            return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device)
        else:
            return torch.rand(size=size, generator=self.rng_)

    def randn(self, *size, type_as=None):
        if type_as is not None:
            return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device)
        else:
            return torch.randn(size=size, generator=self.rng_)

    def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
        if type_as is None:
            return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape)
        else:
            return torch.sparse_coo_tensor(
                torch.stack([rows, cols]), data, size=shape,
                dtype=type_as.dtype, device=type_as.device
            )

    def issparse(self, a):
        return getattr(a, "is_sparse", False) or getattr(a, "is_sparse_csr", False)

    def tocsr(self, a):
        # Versions older than 1.9 do not support CSR tensors;
        # PyTorch 1.9 and 1.10 offer only very limited support
        return self.todense(a)

    def eliminate_zeros(self, a, threshold=0.):
        if self.issparse(a):
            if threshold > 0:
                mask = self.abs(a) <= threshold
                mask = ~mask
                mask = mask.nonzero()
            else:
                mask = a._values().nonzero()
            nv = a._values().index_select(0, mask.view(-1))
            ni = a._indices().index_select(1, mask.view(-1))
            return self.coo_matrix(nv, ni[0], ni[1], shape=a.shape, type_as=a)
        else:
            if threshold > 0:
                a[self.abs(a) <= threshold] = 0
            return a

    def todense(self, a):
        if self.issparse(a):
            return a.to_dense()
        else:
            return a

    def where(self, condition, x, y):
        return torch.where(condition, x, y)

    def copy(self, a):
        return torch.clone(a)

    def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
        return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)

    def dtype_device(self, a):
        return a.dtype, a.device

    def assert_same_dtype_device(self, a, b):
        a_dtype, a_device = self.dtype_device(a)
        b_dtype, b_device = self.dtype_device(b)

        assert a_dtype == b_dtype, "Dtype discrepancy"
        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
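
A minimal end-to-end sketch of the new module (assuming NumPy and PyTorch are
both installed): the same function runs unchanged on either backend.

>>> import numpy as np
>>> import torch
>>> from ot.backend import get_backend
>>> def sqeuclidean(x, y):
...     nx = get_backend(x, y)  # NumpyBackend or TorchBackend, depending on inputs
...     return nx.sum((x - y) ** 2)
>>> float(sqeuclidean(np.ones(3), np.zeros(3)))
3.0
>>> float(sqeuclidean(torch.ones(3), torch.zeros(3)))
3.0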