summaryrefslogtreecommitdiff
path: root/ot/backend.py
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-10-25 11:36:21 +0200
committerGitHub <noreply@github.com>2021-10-25 11:36:21 +0200
commit7a65086dd340265d0223eb8ffb5c9a5152a82dff (patch)
tree300f4a1cd645516fba1e440691fe48830d781b5c /ot/backend.py
parent7af8c2147d61349f4d99ca33318a8a125e4569aa (diff)
[MRG] Bregman backend (#280)
* Bregman * Resolve conflicts * Bug solve * Bregman updated for JAX compatibility * Tests coherence between backend improved * No longer enforcing 64 bits operations on Jax except for tests * Now using mixtures, to make backend dependent tests with less code * Better test skipping code * Pep8 + test optimizations * redundancy removed * Docs * Typo corrected * Typo * Typo * Docs * Docs * pep8 * Backend docs * Prettier docs * Mistake corrected * small changes * Better wording Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/backend.py')
-rw-r--r--ot/backend.py581
1 files changed, 563 insertions, 18 deletions
diff --git a/ot/backend.py b/ot/backend.py
index 2ed40af..a4a4757 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1,6 +1,22 @@
# -*- coding: utf-8 -*-
"""
Multi-lib backend for POT
+
+The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch,
+or Jax, POT code should work nonetheless.
+To achieve that, POT provides backend classes which implements functions in their respective backend
+imitating Numpy API. As a convention, we use nx instead of np to refer to the backend.
+
+Examples
+--------
+
+>>> from ot.utils import list_to_array
+>>> from ot.backend import get_backend
+>>> def f(a, b): # the function does not know which backend to use
+... a, b = list_to_array(a, b) # if a list in given, make it an array
+... nx = get_backend(a, b) # infer the backend from the arguments
+... c = nx.dot(a, b) # now use the backend to do any calculation
+... return c
"""
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
@@ -9,6 +25,7 @@ Multi-lib backend for POT
# License: MIT License
import numpy as np
+import scipy.special as scipy
try:
import torch
@@ -20,6 +37,7 @@ except ImportError:
try:
import jax
import jax.numpy as jnp
+ import jax.scipy.special as jscipy
jax_type = jax.numpy.ndarray
except ImportError:
jax = False
@@ -29,7 +47,7 @@ str_type_error = "All array should be from the same type/backend. Current types
def get_backend_list():
- """ returns the list of available backends)"""
+ """Returns the list of available backends"""
lst = [NumpyBackend(), ]
if torch:
@@ -42,7 +60,7 @@ def get_backend_list():
def get_backend(*args):
- """returns the proper backend for a list of input arrays
+ """Returns the proper backend for a list of input arrays
Also raises TypeError if all arrays are not from the same backend
"""
@@ -50,14 +68,12 @@ def get_backend(*args):
if not len(args) > 0:
raise ValueError(" The function takes at least one parameter")
# check all same type
+ if not len(set(type(a) for a in args)) == 1:
+ raise ValueError(str_type_error.format([type(a) for a in args]))
if isinstance(args[0], np.ndarray):
- if not len(set(type(a) for a in args)) == 1:
- raise ValueError(str_type_error.format([type(a) for a in args]))
return NumpyBackend()
- elif torch and isinstance(args[0], torch_type):
- if not len(set(type(a) for a in args)) == 1:
- raise ValueError(str_type_error.format([type(a) for a in args]))
+ elif isinstance(args[0], torch_type):
return TorchBackend()
elif isinstance(args[0], jax_type):
return JaxBackend()
@@ -66,7 +82,7 @@ def get_backend(*args):
def to_numpy(*args):
- """returns numpy arrays from any compatible backend"""
+ """Returns numpy arrays from any compatible backend"""
if len(args) == 1:
return get_backend(args[0]).to_numpy(args[0])
@@ -75,6 +91,13 @@ def to_numpy(*args):
class Backend():
+ """
+ Backend abstract class.
+ Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`
+
+ - The `__name__` class attribute refers to the name of the backend.
+ - The `__type__` class attribute refers to the data structure used by the backend.
+ """
__name__ = None
__type__ = None
@@ -84,90 +107,426 @@ class Backend():
# convert to numpy
def to_numpy(self, a):
+ """Returns the numpy version of a tensor"""
raise NotImplementedError()
# convert from numpy
def from_numpy(self, a, type_as=None):
+ """Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)"""
raise NotImplementedError()
def set_gradients(self, val, inputs, grads):
- """ define the gradients for the value val wrt the inputs """
+ """Define the gradients for the value val wrt the inputs """
raise NotImplementedError()
def zeros(self, shape, type_as=None):
+ r"""
+ Creates a tensor full of zeros.
+
+ This function follow the api from :any:`numpy.zeros`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html
+ """
raise NotImplementedError()
def ones(self, shape, type_as=None):
+ r"""
+ Creates a tensor full of ones.
+
+ This function follow the api from :any:`numpy.ones`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html
+ """
raise NotImplementedError()
def arange(self, stop, start=0, step=1, type_as=None):
+ r"""
+ Returns evenly spaced values within a given interval.
+
+ This function follow the api from :any:`numpy.arange`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html
+ """
raise NotImplementedError()
def full(self, shape, fill_value, type_as=None):
+ r"""
+ Creates a tensor with given shape, filled with given value.
+
+ This function follow the api from :any:`numpy.full`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.full.html
+ """
raise NotImplementedError()
def eye(self, N, M=None, type_as=None):
+ r"""
+ Creates the identity matrix of given size.
+
+ This function follow the api from :any:`numpy.eye`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html
+ """
raise NotImplementedError()
def sum(self, a, axis=None, keepdims=False):
+ r"""
+ Sums tensor elements over given dimensions.
+
+ This function follow the api from :any:`numpy.sum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html
+ """
raise NotImplementedError()
def cumsum(self, a, axis=None):
+ r"""
+ Returns the cumulative sum of tensor elements over given dimensions.
+
+ This function follow the api from :any:`numpy.cumsum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html
+ """
raise NotImplementedError()
def max(self, a, axis=None, keepdims=False):
+ r"""
+ Returns the maximum of an array or maximum along given dimensions.
+
+ This function follow the api from :any:`numpy.amax`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html
+ """
raise NotImplementedError()
def min(self, a, axis=None, keepdims=False):
+ r"""
+ Returns the maximum of an array or maximum along given dimensions.
+
+ This function follow the api from :any:`numpy.amin`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html
+ """
raise NotImplementedError()
def maximum(self, a, b):
+ r"""
+ Returns element-wise maximum of array elements.
+
+ This function follow the api from :any:`numpy.maximum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html
+ """
raise NotImplementedError()
def minimum(self, a, b):
+ r"""
+ Returns element-wise minimum of array elements.
+
+ This function follow the api from :any:`numpy.minimum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html
+ """
raise NotImplementedError()
def dot(self, a, b):
+ r"""
+ Returns the dot product of two tensors.
+
+ This function follow the api from :any:`numpy.dot`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
+ """
raise NotImplementedError()
def abs(self, a):
+ r"""
+ Computes the absolute value element-wise.
+
+ This function follow the api from :any:`numpy.absolute`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html
+ """
raise NotImplementedError()
def exp(self, a):
+ r"""
+ Computes the exponential value element-wise.
+
+ This function follow the api from :any:`numpy.exp`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html
+ """
raise NotImplementedError()
def log(self, a):
+ r"""
+ Computes the natural logarithm, element-wise.
+
+ This function follow the api from :any:`numpy.log`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.log.html
+ """
raise NotImplementedError()
def sqrt(self, a):
+ r"""
+ Returns the non-ngeative square root of a tensor, element-wise.
+
+ This function follow the api from :any:`numpy.sqrt`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html
+ """
+ raise NotImplementedError()
+
+ def power(self, a, exponents):
+ r"""
+ First tensor elements raised to powers from second tensor, element-wise.
+
+ This function follow the api from :any:`numpy.power`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.power.html
+ """
raise NotImplementedError()
def norm(self, a):
+ r"""
+ Computes the matrix frobenius norm.
+
+ This function follow the api from :any:`numpy.linalg.norm`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html
+ """
raise NotImplementedError()
def any(self, a):
+ r"""
+ Tests whether any tensor element along given dimensions evaluates to True.
+
+ This function follow the api from :any:`numpy.any`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.any.html
+ """
raise NotImplementedError()
def isnan(self, a):
+ r"""
+ Tests element-wise for NaN and returns result as a boolean tensor.
+
+ This function follow the api from :any:`numpy.isnan`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html
+ """
raise NotImplementedError()
def isinf(self, a):
+ r"""
+ Tests element-wise for positive or negative infinity and returns result as a boolean tensor.
+
+ This function follow the api from :any:`numpy.isinf`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html
+ """
raise NotImplementedError()
def einsum(self, subscripts, *operands):
+ r"""
+ Evaluates the Einstein summation convention on the operands.
+
+ This function follow the api from :any:`numpy.einsum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
+ """
raise NotImplementedError()
def sort(self, a, axis=-1):
+ r"""
+ Returns a sorted copy of a tensor.
+
+ This function follow the api from :any:`numpy.sort`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html
+ """
raise NotImplementedError()
def argsort(self, a, axis=None):
+ r"""
+ Returns the indices that would sort a tensor.
+
+ This function follow the api from :any:`numpy.argsort`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html
+ """
+ raise NotImplementedError()
+
+ def searchsorted(self, a, v, side='left'):
+ r"""
+ Finds indices where elements should be inserted to maintain order in given tensor.
+
+ This function follow the api from :any:`numpy.searchsorted`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html
+ """
raise NotImplementedError()
def flip(self, a, axis=None):
+ r"""
+ Reverses the order of elements in a tensor along given dimensions.
+
+ This function follow the api from :any:`numpy.flip`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html
+ """
+ raise NotImplementedError()
+
+ def clip(self, a, a_min, a_max):
+ """
+ Limits the values in a tensor.
+
+ This function follow the api from :any:`numpy.clip`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html
+ """
+ raise NotImplementedError()
+
+ def repeat(self, a, repeats, axis=None):
+ r"""
+ Repeats elements of a tensor.
+
+ This function follow the api from :any:`numpy.repeat`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html
+ """
+ raise NotImplementedError()
+
+ def take_along_axis(self, arr, indices, axis):
+ r"""
+ Gathers elements of a tensor along given dimensions.
+
+ This function follow the api from :any:`numpy.take_along_axis`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html
+ """
+ raise NotImplementedError()
+
+ def concatenate(self, arrays, axis=0):
+ r"""
+ Joins a sequence of tensors along an existing dimension.
+
+ This function follow the api from :any:`numpy.concatenate`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html
+ """
+ raise NotImplementedError()
+
+ def zero_pad(self, a, pad_width):
+ r"""
+ Pads a tensor.
+
+ This function follow the api from :any:`numpy.pad`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
+ """
+ raise NotImplementedError()
+
+ def argmax(self, a, axis=None):
+ r"""
+ Returns the indices of the maximum values of a tensor along given dimensions.
+
+ This function follow the api from :any:`numpy.argmax`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html
+ """
+ raise NotImplementedError()
+
+ def mean(self, a, axis=None):
+ r"""
+ Computes the arithmetic mean of a tensor along given dimensions.
+
+ This function follow the api from :any:`numpy.mean`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html
+ """
+ raise NotImplementedError()
+
+ def std(self, a, axis=None):
+ r"""
+ Computes the standard deviation of a tensor along given dimensions.
+
+ This function follow the api from :any:`numpy.std`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.std.html
+ """
+ raise NotImplementedError()
+
+ def linspace(self, start, stop, num):
+ r"""
+ Returns a specified number of evenly spaced values over a given interval.
+
+ This function follow the api from :any:`numpy.linspace`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html
+ """
+ raise NotImplementedError()
+
+ def meshgrid(self, a, b):
+ r"""
+ Returns coordinate matrices from coordinate vectors (Numpy convention).
+
+ This function follow the api from :any:`numpy.meshgrid`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html
+ """
+ raise NotImplementedError()
+
+ def diag(self, a, k=0):
+ r"""
+ Extracts or constructs a diagonal tensor.
+
+ This function follow the api from :any:`numpy.diag`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html
+ """
+ raise NotImplementedError()
+
+ def unique(self, a):
+ r"""
+ Finds unique elements of given tensor.
+
+ This function follow the api from :any:`numpy.unique`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html
+ """
+ raise NotImplementedError()
+
+ def logsumexp(self, a, axis=None):
+ r"""
+ Computes the log of the sum of exponentials of input elements.
+
+ This function follow the api from :any:`scipy.special.logsumexp`
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
+ """
+ raise NotImplementedError()
+
+ def stack(self, arrays, axis=0):
+ r"""
+ Joins a sequence of tensors along a new dimension.
+
+ This function follow the api from :any:`numpy.stack`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html
+ """
raise NotImplementedError()
class NumpyBackend(Backend):
+ """
+ NumPy implementation of the backend
+
+ - `__name__` is "numpy"
+ - `__type__` is np.ndarray
+ """
__name__ = 'numpy'
__type__ = np.ndarray
@@ -184,7 +543,7 @@ class NumpyBackend(Backend):
return a.astype(type_as.dtype)
def set_gradients(self, val, inputs, grads):
- # no gradients for numpy
+ # No gradients for numpy
return val
def zeros(self, shape, type_as=None):
@@ -247,6 +606,9 @@ class NumpyBackend(Backend):
def sqrt(self, a):
return np.sqrt(a)
+ def power(self, a, exponents):
+ return np.power(a, exponents)
+
def norm(self, a):
return np.sqrt(np.sum(np.square(a)))
@@ -268,11 +630,70 @@ class NumpyBackend(Backend):
def argsort(self, a, axis=-1):
return np.argsort(a, axis)
+ def searchsorted(self, a, v, side='left'):
+ if a.ndim == 1:
+ return np.searchsorted(a, v, side)
+ else:
+ # this is a not very efficient way to make numpy
+ # searchsorted work on 2d arrays
+ ret = np.empty(v.shape, dtype=int)
+ for i in range(a.shape[0]):
+ ret[i, :] = np.searchsorted(a[i, :], v[i, :], side)
+ return ret
+
def flip(self, a, axis=None):
return np.flip(a, axis)
+ def clip(self, a, a_min, a_max):
+ return np.clip(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return np.repeat(a, repeats, axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return np.take_along_axis(arr, indices, axis)
+
+ def concatenate(self, arrays, axis=0):
+ return np.concatenate(arrays, axis)
+
+ def zero_pad(self, a, pad_width):
+ return np.pad(a, pad_width)
+
+ def argmax(self, a, axis=None):
+ return np.argmax(a, axis=axis)
+
+ def mean(self, a, axis=None):
+ return np.mean(a, axis=axis)
+
+ def std(self, a, axis=None):
+ return np.std(a, axis=axis)
+
+ def linspace(self, start, stop, num):
+ return np.linspace(start, stop, num)
+
+ def meshgrid(self, a, b):
+ return np.meshgrid(a, b)
+
+ def diag(self, a, k=0):
+ return np.diag(a, k)
+
+ def unique(self, a):
+ return np.unique(a)
+
+ def logsumexp(self, a, axis=None):
+ return scipy.logsumexp(a, axis=axis)
+
+ def stack(self, arrays, axis=0):
+ return np.stack(arrays, axis)
+
class JaxBackend(Backend):
+ """
+ JAX implementation of the backend
+
+ - `__name__` is "jax"
+ - `__type__` is jax.numpy.ndarray
+ """
__name__ = 'jax'
__type__ = jax_type
@@ -359,6 +780,9 @@ class JaxBackend(Backend):
def sqrt(self, a):
return jnp.sqrt(a)
+ def power(self, a, exponents):
+ return jnp.power(a, exponents)
+
def norm(self, a):
return jnp.sqrt(jnp.sum(jnp.square(a)))
@@ -380,11 +804,67 @@ class JaxBackend(Backend):
def argsort(self, a, axis=-1):
return jnp.argsort(a, axis)
+ def searchsorted(self, a, v, side='left'):
+ if a.ndim == 1:
+ return jnp.searchsorted(a, v, side)
+ else:
+ # this is a not very efficient way to make jax numpy
+ # searchsorted work on 2d arrays
+ return jnp.array([jnp.searchsorted(a[i, :], v[i, :], side) for i in range(a.shape[0])])
+
def flip(self, a, axis=None):
return jnp.flip(a, axis)
+ def clip(self, a, a_min, a_max):
+ return jnp.clip(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return jnp.repeat(a, repeats, axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return jnp.take_along_axis(arr, indices, axis)
+
+ def concatenate(self, arrays, axis=0):
+ return jnp.concatenate(arrays, axis)
+
+ def zero_pad(self, a, pad_width):
+ return jnp.pad(a, pad_width)
+
+ def argmax(self, a, axis=None):
+ return jnp.argmax(a, axis=axis)
+
+ def mean(self, a, axis=None):
+ return jnp.mean(a, axis=axis)
+
+ def std(self, a, axis=None):
+ return jnp.std(a, axis=axis)
+
+ def linspace(self, start, stop, num):
+ return jnp.linspace(start, stop, num)
+
+ def meshgrid(self, a, b):
+ return jnp.meshgrid(a, b)
+
+ def diag(self, a, k=0):
+ return jnp.diag(a, k)
+
+ def unique(self, a):
+ return jnp.unique(a)
+
+ def logsumexp(self, a, axis=None):
+ return jscipy.logsumexp(a, axis=axis)
+
+ def stack(self, arrays, axis=0):
+ return jnp.stack(arrays, axis)
+
class TorchBackend(Backend):
+ """
+ PyTorch implementation of the backend
+
+ - `__name__` is "torch"
+ - `__type__` is torch.Tensor
+ """
__name__ = 'torch'
__type__ = torch_type
@@ -487,22 +967,23 @@ class TorchBackend(Backend):
a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
if isinstance(b, int) or isinstance(b, float):
b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
- return torch.maximum(a, b)
+ if torch.__version__ >= '1.7.0':
+ return torch.maximum(a, b)
+ else:
+ return torch.max(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]
def minimum(self, a, b):
if isinstance(a, int) or isinstance(a, float):
a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
if isinstance(b, int) or isinstance(b, float):
b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
- return torch.minimum(a, b)
+ if torch.__version__ >= '1.7.0':
+ return torch.minimum(a, b)
+ else:
+ return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]
def dot(self, a, b):
- if len(a.shape) == len(b.shape) == 1:
- return torch.dot(a, b)
- elif len(a.shape) == 2 and len(b.shape) == 1:
- return torch.mv(a, b)
- else:
- return torch.mm(a, b)
+ return torch.matmul(a, b)
def abs(self, a):
return torch.abs(a)
@@ -516,6 +997,9 @@ class TorchBackend(Backend):
def sqrt(self, a):
return torch.sqrt(a)
+ def power(self, a, exponents):
+ return torch.pow(a, exponents)
+
def norm(self, a):
return torch.sqrt(torch.sum(torch.square(a)))
@@ -539,6 +1023,10 @@ class TorchBackend(Backend):
sorted, indices = torch.sort(a, dim=axis)
return indices
+ def searchsorted(self, a, v, side='left'):
+ right = (side != 'left')
+ return torch.searchsorted(a, v, right=right)
+
def flip(self, a, axis=None):
if axis is None:
return torch.flip(a, tuple(i for i in range(len(a.shape))))
@@ -546,3 +1034,60 @@ class TorchBackend(Backend):
return torch.flip(a, (axis,))
else:
return torch.flip(a, dims=axis)
+
+ def clip(self, a, a_min, a_max):
+ return torch.clamp(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return torch.repeat_interleave(a, repeats, dim=axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return torch.gather(arr, axis, indices)
+
+ def concatenate(self, arrays, axis=0):
+ return torch.cat(arrays, dim=axis)
+
+ def zero_pad(self, a, pad_width):
+ from torch.nn.functional import pad
+ # pad_width is an array of ndim tuples indicating how many 0 before and after
+ # we need to add. We first need to make it compliant with torch syntax, that
+ # starts with the last dim, then second last, etc.
+ how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl)
+ return pad(a, how_pad)
+
+ def argmax(self, a, axis=None):
+ return torch.argmax(a, dim=axis)
+
+ def mean(self, a, axis=None):
+ if axis is not None:
+ return torch.mean(a, dim=axis)
+ else:
+ return torch.mean(a)
+
+ def std(self, a, axis=None):
+ if axis is not None:
+ return torch.std(a, dim=axis, unbiased=False)
+ else:
+ return torch.std(a, unbiased=False)
+
+ def linspace(self, start, stop, num):
+ return torch.linspace(start, stop, num, dtype=torch.float64)
+
+ def meshgrid(self, a, b):
+ X, Y = torch.meshgrid(a, b)
+ return X.T, Y.T
+
+ def diag(self, a, k=0):
+ return torch.diag(a, diagonal=k)
+
+ def unique(self, a):
+ return torch.unique(a)
+
+ def logsumexp(self, a, axis=None):
+ if axis is not None:
+ return torch.logsumexp(a, dim=axis)
+ else:
+ return torch.logsumexp(a, dim=tuple(range(len(a.shape))))
+
+ def stack(self, arrays, axis=0):
+ return torch.stack(arrays, dim=axis)