diff options
author | ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> | 2021-10-25 11:36:21 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-25 11:36:21 +0200 |
commit | 7a65086dd340265d0223eb8ffb5c9a5152a82dff (patch) | |
tree | 300f4a1cd645516fba1e440691fe48830d781b5c /ot/backend.py | |
parent | 7af8c2147d61349f4d99ca33318a8a125e4569aa (diff) |
[MRG] Bregman backend (#280)
* Bregman
* Resolve conflicts
* Bug solve
* Bregman updated for JAX compatibility
* Tests coherence between backend improved
* No longer enforcing 64 bits operations on Jax except for tests
* Now using mixtures, to make backend dependent tests with less code
* Better test skipping code
* Pep8 + test optimizations
* redundancy removed
* Docs
* Typo corrected
* Typo
* Typo
* Docs
* Docs
* pep8
* Backend docs
* Prettier docs
* Mistake corrected
* small changes
* Better wording
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/backend.py')
-rw-r--r-- | ot/backend.py | 581 |
1 files changed, 563 insertions, 18 deletions
diff --git a/ot/backend.py b/ot/backend.py index 2ed40af..a4a4757 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1,6 +1,22 @@ # -*- coding: utf-8 -*- """ Multi-lib backend for POT + +The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch, +or Jax, POT code should work nonetheless. +To achieve that, POT provides backend classes which implements functions in their respective backend +imitating Numpy API. As a convention, we use nx instead of np to refer to the backend. + +Examples +-------- + +>>> from ot.utils import list_to_array +>>> from ot.backend import get_backend +>>> def f(a, b): # the function does not know which backend to use +... a, b = list_to_array(a, b) # if a list in given, make it an array +... nx = get_backend(a, b) # infer the backend from the arguments +... c = nx.dot(a, b) # now use the backend to do any calculation +... return c """ # Author: Remi Flamary <remi.flamary@polytechnique.edu> @@ -9,6 +25,7 @@ Multi-lib backend for POT # License: MIT License import numpy as np +import scipy.special as scipy try: import torch @@ -20,6 +37,7 @@ except ImportError: try: import jax import jax.numpy as jnp + import jax.scipy.special as jscipy jax_type = jax.numpy.ndarray except ImportError: jax = False @@ -29,7 +47,7 @@ str_type_error = "All array should be from the same type/backend. Current types def get_backend_list(): - """ returns the list of available backends)""" + """Returns the list of available backends""" lst = [NumpyBackend(), ] if torch: @@ -42,7 +60,7 @@ def get_backend_list(): def get_backend(*args): - """returns the proper backend for a list of input arrays + """Returns the proper backend for a list of input arrays Also raises TypeError if all arrays are not from the same backend """ @@ -50,14 +68,12 @@ def get_backend(*args): if not len(args) > 0: raise ValueError(" The function takes at least one parameter") # check all same type + if not len(set(type(a) for a in args)) == 1: + raise ValueError(str_type_error.format([type(a) for a in args])) if isinstance(args[0], np.ndarray): - if not len(set(type(a) for a in args)) == 1: - raise ValueError(str_type_error.format([type(a) for a in args])) return NumpyBackend() - elif torch and isinstance(args[0], torch_type): - if not len(set(type(a) for a in args)) == 1: - raise ValueError(str_type_error.format([type(a) for a in args])) + elif isinstance(args[0], torch_type): return TorchBackend() elif isinstance(args[0], jax_type): return JaxBackend() @@ -66,7 +82,7 @@ def get_backend(*args): def to_numpy(*args): - """returns numpy arrays from any compatible backend""" + """Returns numpy arrays from any compatible backend""" if len(args) == 1: return get_backend(args[0]).to_numpy(args[0]) @@ -75,6 +91,13 @@ def to_numpy(*args): class Backend(): + """ + Backend abstract class. + Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend` + + - The `__name__` class attribute refers to the name of the backend. + - The `__type__` class attribute refers to the data structure used by the backend. + """ __name__ = None __type__ = None @@ -84,90 +107,426 @@ class Backend(): # convert to numpy def to_numpy(self, a): + """Returns the numpy version of a tensor""" raise NotImplementedError() # convert from numpy def from_numpy(self, a, type_as=None): + """Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)""" raise NotImplementedError() def set_gradients(self, val, inputs, grads): - """ define the gradients for the value val wrt the inputs """ + """Define the gradients for the value val wrt the inputs """ raise NotImplementedError() def zeros(self, shape, type_as=None): + r""" + Creates a tensor full of zeros. + + This function follow the api from :any:`numpy.zeros` + + See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html + """ raise NotImplementedError() def ones(self, shape, type_as=None): + r""" + Creates a tensor full of ones. + + This function follow the api from :any:`numpy.ones` + + See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html + """ raise NotImplementedError() def arange(self, stop, start=0, step=1, type_as=None): + r""" + Returns evenly spaced values within a given interval. + + This function follow the api from :any:`numpy.arange` + + See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html + """ raise NotImplementedError() def full(self, shape, fill_value, type_as=None): + r""" + Creates a tensor with given shape, filled with given value. + + This function follow the api from :any:`numpy.full` + + See: https://numpy.org/doc/stable/reference/generated/numpy.full.html + """ raise NotImplementedError() def eye(self, N, M=None, type_as=None): + r""" + Creates the identity matrix of given size. + + This function follow the api from :any:`numpy.eye` + + See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html + """ raise NotImplementedError() def sum(self, a, axis=None, keepdims=False): + r""" + Sums tensor elements over given dimensions. + + This function follow the api from :any:`numpy.sum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html + """ raise NotImplementedError() def cumsum(self, a, axis=None): + r""" + Returns the cumulative sum of tensor elements over given dimensions. + + This function follow the api from :any:`numpy.cumsum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html + """ raise NotImplementedError() def max(self, a, axis=None, keepdims=False): + r""" + Returns the maximum of an array or maximum along given dimensions. + + This function follow the api from :any:`numpy.amax` + + See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html + """ raise NotImplementedError() def min(self, a, axis=None, keepdims=False): + r""" + Returns the maximum of an array or maximum along given dimensions. + + This function follow the api from :any:`numpy.amin` + + See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html + """ raise NotImplementedError() def maximum(self, a, b): + r""" + Returns element-wise maximum of array elements. + + This function follow the api from :any:`numpy.maximum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html + """ raise NotImplementedError() def minimum(self, a, b): + r""" + Returns element-wise minimum of array elements. + + This function follow the api from :any:`numpy.minimum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html + """ raise NotImplementedError() def dot(self, a, b): + r""" + Returns the dot product of two tensors. + + This function follow the api from :any:`numpy.dot` + + See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html + """ raise NotImplementedError() def abs(self, a): + r""" + Computes the absolute value element-wise. + + This function follow the api from :any:`numpy.absolute` + + See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html + """ raise NotImplementedError() def exp(self, a): + r""" + Computes the exponential value element-wise. + + This function follow the api from :any:`numpy.exp` + + See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html + """ raise NotImplementedError() def log(self, a): + r""" + Computes the natural logarithm, element-wise. + + This function follow the api from :any:`numpy.log` + + See: https://numpy.org/doc/stable/reference/generated/numpy.log.html + """ raise NotImplementedError() def sqrt(self, a): + r""" + Returns the non-ngeative square root of a tensor, element-wise. + + This function follow the api from :any:`numpy.sqrt` + + See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html + """ + raise NotImplementedError() + + def power(self, a, exponents): + r""" + First tensor elements raised to powers from second tensor, element-wise. + + This function follow the api from :any:`numpy.power` + + See: https://numpy.org/doc/stable/reference/generated/numpy.power.html + """ raise NotImplementedError() def norm(self, a): + r""" + Computes the matrix frobenius norm. + + This function follow the api from :any:`numpy.linalg.norm` + + See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html + """ raise NotImplementedError() def any(self, a): + r""" + Tests whether any tensor element along given dimensions evaluates to True. + + This function follow the api from :any:`numpy.any` + + See: https://numpy.org/doc/stable/reference/generated/numpy.any.html + """ raise NotImplementedError() def isnan(self, a): + r""" + Tests element-wise for NaN and returns result as a boolean tensor. + + This function follow the api from :any:`numpy.isnan` + + See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html + """ raise NotImplementedError() def isinf(self, a): + r""" + Tests element-wise for positive or negative infinity and returns result as a boolean tensor. + + This function follow the api from :any:`numpy.isinf` + + See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html + """ raise NotImplementedError() def einsum(self, subscripts, *operands): + r""" + Evaluates the Einstein summation convention on the operands. + + This function follow the api from :any:`numpy.einsum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html + """ raise NotImplementedError() def sort(self, a, axis=-1): + r""" + Returns a sorted copy of a tensor. + + This function follow the api from :any:`numpy.sort` + + See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html + """ raise NotImplementedError() def argsort(self, a, axis=None): + r""" + Returns the indices that would sort a tensor. + + This function follow the api from :any:`numpy.argsort` + + See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html + """ + raise NotImplementedError() + + def searchsorted(self, a, v, side='left'): + r""" + Finds indices where elements should be inserted to maintain order in given tensor. + + This function follow the api from :any:`numpy.searchsorted` + + See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html + """ raise NotImplementedError() def flip(self, a, axis=None): + r""" + Reverses the order of elements in a tensor along given dimensions. + + This function follow the api from :any:`numpy.flip` + + See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html + """ + raise NotImplementedError() + + def clip(self, a, a_min, a_max): + """ + Limits the values in a tensor. + + This function follow the api from :any:`numpy.clip` + + See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html + """ + raise NotImplementedError() + + def repeat(self, a, repeats, axis=None): + r""" + Repeats elements of a tensor. + + This function follow the api from :any:`numpy.repeat` + + See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html + """ + raise NotImplementedError() + + def take_along_axis(self, arr, indices, axis): + r""" + Gathers elements of a tensor along given dimensions. + + This function follow the api from :any:`numpy.take_along_axis` + + See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html + """ + raise NotImplementedError() + + def concatenate(self, arrays, axis=0): + r""" + Joins a sequence of tensors along an existing dimension. + + This function follow the api from :any:`numpy.concatenate` + + See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html + """ + raise NotImplementedError() + + def zero_pad(self, a, pad_width): + r""" + Pads a tensor. + + This function follow the api from :any:`numpy.pad` + + See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + """ + raise NotImplementedError() + + def argmax(self, a, axis=None): + r""" + Returns the indices of the maximum values of a tensor along given dimensions. + + This function follow the api from :any:`numpy.argmax` + + See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html + """ + raise NotImplementedError() + + def mean(self, a, axis=None): + r""" + Computes the arithmetic mean of a tensor along given dimensions. + + This function follow the api from :any:`numpy.mean` + + See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html + """ + raise NotImplementedError() + + def std(self, a, axis=None): + r""" + Computes the standard deviation of a tensor along given dimensions. + + This function follow the api from :any:`numpy.std` + + See: https://numpy.org/doc/stable/reference/generated/numpy.std.html + """ + raise NotImplementedError() + + def linspace(self, start, stop, num): + r""" + Returns a specified number of evenly spaced values over a given interval. + + This function follow the api from :any:`numpy.linspace` + + See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html + """ + raise NotImplementedError() + + def meshgrid(self, a, b): + r""" + Returns coordinate matrices from coordinate vectors (Numpy convention). + + This function follow the api from :any:`numpy.meshgrid` + + See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html + """ + raise NotImplementedError() + + def diag(self, a, k=0): + r""" + Extracts or constructs a diagonal tensor. + + This function follow the api from :any:`numpy.diag` + + See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html + """ + raise NotImplementedError() + + def unique(self, a): + r""" + Finds unique elements of given tensor. + + This function follow the api from :any:`numpy.unique` + + See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html + """ + raise NotImplementedError() + + def logsumexp(self, a, axis=None): + r""" + Computes the log of the sum of exponentials of input elements. + + This function follow the api from :any:`scipy.special.logsumexp` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html + """ + raise NotImplementedError() + + def stack(self, arrays, axis=0): + r""" + Joins a sequence of tensors along a new dimension. + + This function follow the api from :any:`numpy.stack` + + See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html + """ raise NotImplementedError() class NumpyBackend(Backend): + """ + NumPy implementation of the backend + + - `__name__` is "numpy" + - `__type__` is np.ndarray + """ __name__ = 'numpy' __type__ = np.ndarray @@ -184,7 +543,7 @@ class NumpyBackend(Backend): return a.astype(type_as.dtype) def set_gradients(self, val, inputs, grads): - # no gradients for numpy + # No gradients for numpy return val def zeros(self, shape, type_as=None): @@ -247,6 +606,9 @@ class NumpyBackend(Backend): def sqrt(self, a): return np.sqrt(a) + def power(self, a, exponents): + return np.power(a, exponents) + def norm(self, a): return np.sqrt(np.sum(np.square(a))) @@ -268,11 +630,70 @@ class NumpyBackend(Backend): def argsort(self, a, axis=-1): return np.argsort(a, axis) + def searchsorted(self, a, v, side='left'): + if a.ndim == 1: + return np.searchsorted(a, v, side) + else: + # this is a not very efficient way to make numpy + # searchsorted work on 2d arrays + ret = np.empty(v.shape, dtype=int) + for i in range(a.shape[0]): + ret[i, :] = np.searchsorted(a[i, :], v[i, :], side) + return ret + def flip(self, a, axis=None): return np.flip(a, axis) + def clip(self, a, a_min, a_max): + return np.clip(a, a_min, a_max) + + def repeat(self, a, repeats, axis=None): + return np.repeat(a, repeats, axis) + + def take_along_axis(self, arr, indices, axis): + return np.take_along_axis(arr, indices, axis) + + def concatenate(self, arrays, axis=0): + return np.concatenate(arrays, axis) + + def zero_pad(self, a, pad_width): + return np.pad(a, pad_width) + + def argmax(self, a, axis=None): + return np.argmax(a, axis=axis) + + def mean(self, a, axis=None): + return np.mean(a, axis=axis) + + def std(self, a, axis=None): + return np.std(a, axis=axis) + + def linspace(self, start, stop, num): + return np.linspace(start, stop, num) + + def meshgrid(self, a, b): + return np.meshgrid(a, b) + + def diag(self, a, k=0): + return np.diag(a, k) + + def unique(self, a): + return np.unique(a) + + def logsumexp(self, a, axis=None): + return scipy.logsumexp(a, axis=axis) + + def stack(self, arrays, axis=0): + return np.stack(arrays, axis) + class JaxBackend(Backend): + """ + JAX implementation of the backend + + - `__name__` is "jax" + - `__type__` is jax.numpy.ndarray + """ __name__ = 'jax' __type__ = jax_type @@ -359,6 +780,9 @@ class JaxBackend(Backend): def sqrt(self, a): return jnp.sqrt(a) + def power(self, a, exponents): + return jnp.power(a, exponents) + def norm(self, a): return jnp.sqrt(jnp.sum(jnp.square(a))) @@ -380,11 +804,67 @@ class JaxBackend(Backend): def argsort(self, a, axis=-1): return jnp.argsort(a, axis) + def searchsorted(self, a, v, side='left'): + if a.ndim == 1: + return jnp.searchsorted(a, v, side) + else: + # this is a not very efficient way to make jax numpy + # searchsorted work on 2d arrays + return jnp.array([jnp.searchsorted(a[i, :], v[i, :], side) for i in range(a.shape[0])]) + def flip(self, a, axis=None): return jnp.flip(a, axis) + def clip(self, a, a_min, a_max): + return jnp.clip(a, a_min, a_max) + + def repeat(self, a, repeats, axis=None): + return jnp.repeat(a, repeats, axis) + + def take_along_axis(self, arr, indices, axis): + return jnp.take_along_axis(arr, indices, axis) + + def concatenate(self, arrays, axis=0): + return jnp.concatenate(arrays, axis) + + def zero_pad(self, a, pad_width): + return jnp.pad(a, pad_width) + + def argmax(self, a, axis=None): + return jnp.argmax(a, axis=axis) + + def mean(self, a, axis=None): + return jnp.mean(a, axis=axis) + + def std(self, a, axis=None): + return jnp.std(a, axis=axis) + + def linspace(self, start, stop, num): + return jnp.linspace(start, stop, num) + + def meshgrid(self, a, b): + return jnp.meshgrid(a, b) + + def diag(self, a, k=0): + return jnp.diag(a, k) + + def unique(self, a): + return jnp.unique(a) + + def logsumexp(self, a, axis=None): + return jscipy.logsumexp(a, axis=axis) + + def stack(self, arrays, axis=0): + return jnp.stack(arrays, axis) + class TorchBackend(Backend): + """ + PyTorch implementation of the backend + + - `__name__` is "torch" + - `__type__` is torch.Tensor + """ __name__ = 'torch' __type__ = torch_type @@ -487,22 +967,23 @@ class TorchBackend(Backend): a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) if isinstance(b, int) or isinstance(b, float): b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) - return torch.maximum(a, b) + if torch.__version__ >= '1.7.0': + return torch.maximum(a, b) + else: + return torch.max(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0] def minimum(self, a, b): if isinstance(a, int) or isinstance(a, float): a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) if isinstance(b, int) or isinstance(b, float): b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) - return torch.minimum(a, b) + if torch.__version__ >= '1.7.0': + return torch.minimum(a, b) + else: + return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0] def dot(self, a, b): - if len(a.shape) == len(b.shape) == 1: - return torch.dot(a, b) - elif len(a.shape) == 2 and len(b.shape) == 1: - return torch.mv(a, b) - else: - return torch.mm(a, b) + return torch.matmul(a, b) def abs(self, a): return torch.abs(a) @@ -516,6 +997,9 @@ class TorchBackend(Backend): def sqrt(self, a): return torch.sqrt(a) + def power(self, a, exponents): + return torch.pow(a, exponents) + def norm(self, a): return torch.sqrt(torch.sum(torch.square(a))) @@ -539,6 +1023,10 @@ class TorchBackend(Backend): sorted, indices = torch.sort(a, dim=axis) return indices + def searchsorted(self, a, v, side='left'): + right = (side != 'left') + return torch.searchsorted(a, v, right=right) + def flip(self, a, axis=None): if axis is None: return torch.flip(a, tuple(i for i in range(len(a.shape)))) @@ -546,3 +1034,60 @@ class TorchBackend(Backend): return torch.flip(a, (axis,)) else: return torch.flip(a, dims=axis) + + def clip(self, a, a_min, a_max): + return torch.clamp(a, a_min, a_max) + + def repeat(self, a, repeats, axis=None): + return torch.repeat_interleave(a, repeats, dim=axis) + + def take_along_axis(self, arr, indices, axis): + return torch.gather(arr, axis, indices) + + def concatenate(self, arrays, axis=0): + return torch.cat(arrays, dim=axis) + + def zero_pad(self, a, pad_width): + from torch.nn.functional import pad + # pad_width is an array of ndim tuples indicating how many 0 before and after + # we need to add. We first need to make it compliant with torch syntax, that + # starts with the last dim, then second last, etc. + how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl) + return pad(a, how_pad) + + def argmax(self, a, axis=None): + return torch.argmax(a, dim=axis) + + def mean(self, a, axis=None): + if axis is not None: + return torch.mean(a, dim=axis) + else: + return torch.mean(a) + + def std(self, a, axis=None): + if axis is not None: + return torch.std(a, dim=axis, unbiased=False) + else: + return torch.std(a, unbiased=False) + + def linspace(self, start, stop, num): + return torch.linspace(start, stop, num, dtype=torch.float64) + + def meshgrid(self, a, b): + X, Y = torch.meshgrid(a, b) + return X.T, Y.T + + def diag(self, a, k=0): + return torch.diag(a, diagonal=k) + + def unique(self, a): + return torch.unique(a) + + def logsumexp(self, a, axis=None): + if axis is not None: + return torch.logsumexp(a, dim=axis) + else: + return torch.logsumexp(a, dim=tuple(range(len(a.shape)))) + + def stack(self, arrays, axis=0): + return torch.stack(arrays, dim=axis) |