summaryrefslogtreecommitdiff
path: root/ot/backend.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/backend.py')
-rw-r--r--ot/backend.py118
1 files changed, 78 insertions, 40 deletions
diff --git a/ot/backend.py b/ot/backend.py
index a4a4757..876b96a 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -123,7 +123,7 @@ class Backend():
r"""
Creates a tensor full of zeros.
- This function follow the api from :any:`numpy.zeros`
+ This function follows the api from :any:`numpy.zeros`
See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html
"""
@@ -133,7 +133,7 @@ class Backend():
r"""
Creates a tensor full of ones.
- This function follow the api from :any:`numpy.ones`
+ This function follows the api from :any:`numpy.ones`
See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html
"""
@@ -143,7 +143,7 @@ class Backend():
r"""
Returns evenly spaced values within a given interval.
- This function follow the api from :any:`numpy.arange`
+ This function follows the api from :any:`numpy.arange`
See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html
"""
@@ -153,7 +153,7 @@ class Backend():
r"""
Creates a tensor with given shape, filled with given value.
- This function follow the api from :any:`numpy.full`
+ This function follows the api from :any:`numpy.full`
See: https://numpy.org/doc/stable/reference/generated/numpy.full.html
"""
@@ -163,7 +163,7 @@ class Backend():
r"""
Creates the identity matrix of given size.
- This function follow the api from :any:`numpy.eye`
+ This function follows the api from :any:`numpy.eye`
See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html
"""
@@ -173,7 +173,7 @@ class Backend():
r"""
Sums tensor elements over given dimensions.
- This function follow the api from :any:`numpy.sum`
+ This function follows the api from :any:`numpy.sum`
See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html
"""
@@ -183,7 +183,7 @@ class Backend():
r"""
Returns the cumulative sum of tensor elements over given dimensions.
- This function follow the api from :any:`numpy.cumsum`
+ This function follows the api from :any:`numpy.cumsum`
See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html
"""
@@ -193,7 +193,7 @@ class Backend():
r"""
Returns the maximum of an array or maximum along given dimensions.
- This function follow the api from :any:`numpy.amax`
+ This function follows the api from :any:`numpy.amax`
See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html
"""
@@ -203,7 +203,7 @@ class Backend():
r"""
Returns the maximum of an array or maximum along given dimensions.
- This function follow the api from :any:`numpy.amin`
+ This function follows the api from :any:`numpy.amin`
See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html
"""
@@ -213,7 +213,7 @@ class Backend():
r"""
Returns element-wise maximum of array elements.
- This function follow the api from :any:`numpy.maximum`
+ This function follows the api from :any:`numpy.maximum`
See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html
"""
@@ -223,7 +223,7 @@ class Backend():
r"""
Returns element-wise minimum of array elements.
- This function follow the api from :any:`numpy.minimum`
+ This function follows the api from :any:`numpy.minimum`
See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html
"""
@@ -233,7 +233,7 @@ class Backend():
r"""
Returns the dot product of two tensors.
- This function follow the api from :any:`numpy.dot`
+ This function follows the api from :any:`numpy.dot`
See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
"""
@@ -243,7 +243,7 @@ class Backend():
r"""
Computes the absolute value element-wise.
- This function follow the api from :any:`numpy.absolute`
+ This function follows the api from :any:`numpy.absolute`
See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html
"""
@@ -253,7 +253,7 @@ class Backend():
r"""
Computes the exponential value element-wise.
- This function follow the api from :any:`numpy.exp`
+ This function follows the api from :any:`numpy.exp`
See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html
"""
@@ -263,7 +263,7 @@ class Backend():
r"""
Computes the natural logarithm, element-wise.
- This function follow the api from :any:`numpy.log`
+ This function follows the api from :any:`numpy.log`
See: https://numpy.org/doc/stable/reference/generated/numpy.log.html
"""
@@ -273,7 +273,7 @@ class Backend():
r"""
Returns the non-ngeative square root of a tensor, element-wise.
- This function follow the api from :any:`numpy.sqrt`
+ This function follows the api from :any:`numpy.sqrt`
See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html
"""
@@ -283,7 +283,7 @@ class Backend():
r"""
First tensor elements raised to powers from second tensor, element-wise.
- This function follow the api from :any:`numpy.power`
+ This function follows the api from :any:`numpy.power`
See: https://numpy.org/doc/stable/reference/generated/numpy.power.html
"""
@@ -293,7 +293,7 @@ class Backend():
r"""
Computes the matrix frobenius norm.
- This function follow the api from :any:`numpy.linalg.norm`
+ This function follows the api from :any:`numpy.linalg.norm`
See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html
"""
@@ -303,7 +303,7 @@ class Backend():
r"""
Tests whether any tensor element along given dimensions evaluates to True.
- This function follow the api from :any:`numpy.any`
+ This function follows the api from :any:`numpy.any`
See: https://numpy.org/doc/stable/reference/generated/numpy.any.html
"""
@@ -313,7 +313,7 @@ class Backend():
r"""
Tests element-wise for NaN and returns result as a boolean tensor.
- This function follow the api from :any:`numpy.isnan`
+ This function follows the api from :any:`numpy.isnan`
See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html
"""
@@ -323,7 +323,7 @@ class Backend():
r"""
Tests element-wise for positive or negative infinity and returns result as a boolean tensor.
- This function follow the api from :any:`numpy.isinf`
+ This function follows the api from :any:`numpy.isinf`
See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html
"""
@@ -333,7 +333,7 @@ class Backend():
r"""
Evaluates the Einstein summation convention on the operands.
- This function follow the api from :any:`numpy.einsum`
+ This function follows the api from :any:`numpy.einsum`
See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
"""
@@ -343,7 +343,7 @@ class Backend():
r"""
Returns a sorted copy of a tensor.
- This function follow the api from :any:`numpy.sort`
+ This function follows the api from :any:`numpy.sort`
See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html
"""
@@ -353,7 +353,7 @@ class Backend():
r"""
Returns the indices that would sort a tensor.
- This function follow the api from :any:`numpy.argsort`
+ This function follows the api from :any:`numpy.argsort`
See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html
"""
@@ -363,7 +363,7 @@ class Backend():
r"""
Finds indices where elements should be inserted to maintain order in given tensor.
- This function follow the api from :any:`numpy.searchsorted`
+ This function follows the api from :any:`numpy.searchsorted`
See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html
"""
@@ -373,7 +373,7 @@ class Backend():
r"""
Reverses the order of elements in a tensor along given dimensions.
- This function follow the api from :any:`numpy.flip`
+ This function follows the api from :any:`numpy.flip`
See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html
"""
@@ -383,7 +383,7 @@ class Backend():
"""
Limits the values in a tensor.
- This function follow the api from :any:`numpy.clip`
+ This function follows the api from :any:`numpy.clip`
See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html
"""
@@ -393,7 +393,7 @@ class Backend():
r"""
Repeats elements of a tensor.
- This function follow the api from :any:`numpy.repeat`
+ This function follows the api from :any:`numpy.repeat`
See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html
"""
@@ -403,7 +403,7 @@ class Backend():
r"""
Gathers elements of a tensor along given dimensions.
- This function follow the api from :any:`numpy.take_along_axis`
+ This function follows the api from :any:`numpy.take_along_axis`
See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html
"""
@@ -413,7 +413,7 @@ class Backend():
r"""
Joins a sequence of tensors along an existing dimension.
- This function follow the api from :any:`numpy.concatenate`
+ This function follows the api from :any:`numpy.concatenate`
See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html
"""
@@ -423,7 +423,7 @@ class Backend():
r"""
Pads a tensor.
- This function follow the api from :any:`numpy.pad`
+ This function follows the api from :any:`numpy.pad`
See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
"""
@@ -433,7 +433,7 @@ class Backend():
r"""
Returns the indices of the maximum values of a tensor along given dimensions.
- This function follow the api from :any:`numpy.argmax`
+ This function follows the api from :any:`numpy.argmax`
See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html
"""
@@ -443,7 +443,7 @@ class Backend():
r"""
Computes the arithmetic mean of a tensor along given dimensions.
- This function follow the api from :any:`numpy.mean`
+ This function follows the api from :any:`numpy.mean`
See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html
"""
@@ -453,7 +453,7 @@ class Backend():
r"""
Computes the standard deviation of a tensor along given dimensions.
- This function follow the api from :any:`numpy.std`
+ This function follows the api from :any:`numpy.std`
See: https://numpy.org/doc/stable/reference/generated/numpy.std.html
"""
@@ -463,7 +463,7 @@ class Backend():
r"""
Returns a specified number of evenly spaced values over a given interval.
- This function follow the api from :any:`numpy.linspace`
+ This function follows the api from :any:`numpy.linspace`
See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html
"""
@@ -473,7 +473,7 @@ class Backend():
r"""
Returns coordinate matrices from coordinate vectors (Numpy convention).
- This function follow the api from :any:`numpy.meshgrid`
+ This function follows the api from :any:`numpy.meshgrid`
See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html
"""
@@ -483,7 +483,7 @@ class Backend():
r"""
Extracts or constructs a diagonal tensor.
- This function follow the api from :any:`numpy.diag`
+ This function follows the api from :any:`numpy.diag`
See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html
"""
@@ -493,7 +493,7 @@ class Backend():
r"""
Finds unique elements of given tensor.
- This function follow the api from :any:`numpy.unique`
+ This function follows the api from :any:`numpy.unique`
See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html
"""
@@ -503,7 +503,7 @@ class Backend():
r"""
Computes the log of the sum of exponentials of input elements.
- This function follow the api from :any:`scipy.special.logsumexp`
+ This function follows the api from :any:`scipy.special.logsumexp`
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
"""
@@ -513,12 +513,32 @@ class Backend():
r"""
Joins a sequence of tensors along a new dimension.
- This function follow the api from :any:`numpy.stack`
+ This function follows the api from :any:`numpy.stack`
See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html
"""
raise NotImplementedError()
+ def outer(self, a, b):
+ r"""
+ Computes the outer product between two vectors.
+
+ This function follows the api from :any:`numpy.outer`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html
+ """
+ raise NotImplementedError()
+
+ def reshape(self, a, shape):
+ r"""
+ Gives a new shape to a tensor without changing its data.
+
+ This function follows the api from :any:`numpy.reshape`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html
+ """
+ raise NotImplementedError()
+
class NumpyBackend(Backend):
"""
@@ -644,6 +664,9 @@ class NumpyBackend(Backend):
def flip(self, a, axis=None):
return np.flip(a, axis)
+ def outer(self, a, b):
+ return np.outer(a, b)
+
def clip(self, a, a_min, a_max):
return np.clip(a, a_min, a_max)
@@ -686,6 +709,9 @@ class NumpyBackend(Backend):
def stack(self, arrays, axis=0):
return np.stack(arrays, axis)
+ def reshape(self, a, shape):
+ return np.reshape(a, shape)
+
class JaxBackend(Backend):
"""
@@ -815,6 +841,9 @@ class JaxBackend(Backend):
def flip(self, a, axis=None):
return jnp.flip(a, axis)
+ def outer(self, a, b):
+ return jnp.outer(a, b)
+
def clip(self, a, a_min, a_max):
return jnp.clip(a, a_min, a_max)
@@ -857,6 +886,9 @@ class JaxBackend(Backend):
def stack(self, arrays, axis=0):
return jnp.stack(arrays, axis)
+ def reshape(self, a, shape):
+ return jnp.reshape(a, shape)
+
class TorchBackend(Backend):
"""
@@ -1035,6 +1067,9 @@ class TorchBackend(Backend):
else:
return torch.flip(a, dims=axis)
+ def outer(self, a, b):
+ return torch.outer(a, b)
+
def clip(self, a, a_min, a_max):
return torch.clamp(a, a_min, a_max)
@@ -1091,3 +1126,6 @@ class TorchBackend(Backend):
def stack(self, arrays, axis=0):
return torch.stack(arrays, dim=axis)
+
+ def reshape(self, a, shape):
+ return torch.reshape(a, shape)