diff options
Diffstat (limited to 'ot/backend.py')
-rw-r--r-- | ot/backend.py | 118 |
1 files changed, 78 insertions, 40 deletions
diff --git a/ot/backend.py b/ot/backend.py index a4a4757..876b96a 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -123,7 +123,7 @@ class Backend(): r""" Creates a tensor full of zeros. - This function follow the api from :any:`numpy.zeros` + This function follows the api from :any:`numpy.zeros` See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html """ @@ -133,7 +133,7 @@ class Backend(): r""" Creates a tensor full of ones. - This function follow the api from :any:`numpy.ones` + This function follows the api from :any:`numpy.ones` See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html """ @@ -143,7 +143,7 @@ class Backend(): r""" Returns evenly spaced values within a given interval. - This function follow the api from :any:`numpy.arange` + This function follows the api from :any:`numpy.arange` See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html """ @@ -153,7 +153,7 @@ class Backend(): r""" Creates a tensor with given shape, filled with given value. - This function follow the api from :any:`numpy.full` + This function follows the api from :any:`numpy.full` See: https://numpy.org/doc/stable/reference/generated/numpy.full.html """ @@ -163,7 +163,7 @@ class Backend(): r""" Creates the identity matrix of given size. - This function follow the api from :any:`numpy.eye` + This function follows the api from :any:`numpy.eye` See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html """ @@ -173,7 +173,7 @@ class Backend(): r""" Sums tensor elements over given dimensions. - This function follow the api from :any:`numpy.sum` + This function follows the api from :any:`numpy.sum` See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html """ @@ -183,7 +183,7 @@ class Backend(): r""" Returns the cumulative sum of tensor elements over given dimensions. - This function follow the api from :any:`numpy.cumsum` + This function follows the api from :any:`numpy.cumsum` See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html """ @@ -193,7 +193,7 @@ class Backend(): r""" Returns the maximum of an array or maximum along given dimensions. - This function follow the api from :any:`numpy.amax` + This function follows the api from :any:`numpy.amax` See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html """ @@ -203,7 +203,7 @@ class Backend(): r""" Returns the maximum of an array or maximum along given dimensions. - This function follow the api from :any:`numpy.amin` + This function follows the api from :any:`numpy.amin` See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html """ @@ -213,7 +213,7 @@ class Backend(): r""" Returns element-wise maximum of array elements. - This function follow the api from :any:`numpy.maximum` + This function follows the api from :any:`numpy.maximum` See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html """ @@ -223,7 +223,7 @@ class Backend(): r""" Returns element-wise minimum of array elements. - This function follow the api from :any:`numpy.minimum` + This function follows the api from :any:`numpy.minimum` See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html """ @@ -233,7 +233,7 @@ class Backend(): r""" Returns the dot product of two tensors. - This function follow the api from :any:`numpy.dot` + This function follows the api from :any:`numpy.dot` See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html """ @@ -243,7 +243,7 @@ class Backend(): r""" Computes the absolute value element-wise. - This function follow the api from :any:`numpy.absolute` + This function follows the api from :any:`numpy.absolute` See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html """ @@ -253,7 +253,7 @@ class Backend(): r""" Computes the exponential value element-wise. - This function follow the api from :any:`numpy.exp` + This function follows the api from :any:`numpy.exp` See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html """ @@ -263,7 +263,7 @@ class Backend(): r""" Computes the natural logarithm, element-wise. - This function follow the api from :any:`numpy.log` + This function follows the api from :any:`numpy.log` See: https://numpy.org/doc/stable/reference/generated/numpy.log.html """ @@ -273,7 +273,7 @@ class Backend(): r""" Returns the non-ngeative square root of a tensor, element-wise. - This function follow the api from :any:`numpy.sqrt` + This function follows the api from :any:`numpy.sqrt` See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html """ @@ -283,7 +283,7 @@ class Backend(): r""" First tensor elements raised to powers from second tensor, element-wise. - This function follow the api from :any:`numpy.power` + This function follows the api from :any:`numpy.power` See: https://numpy.org/doc/stable/reference/generated/numpy.power.html """ @@ -293,7 +293,7 @@ class Backend(): r""" Computes the matrix frobenius norm. - This function follow the api from :any:`numpy.linalg.norm` + This function follows the api from :any:`numpy.linalg.norm` See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html """ @@ -303,7 +303,7 @@ class Backend(): r""" Tests whether any tensor element along given dimensions evaluates to True. - This function follow the api from :any:`numpy.any` + This function follows the api from :any:`numpy.any` See: https://numpy.org/doc/stable/reference/generated/numpy.any.html """ @@ -313,7 +313,7 @@ class Backend(): r""" Tests element-wise for NaN and returns result as a boolean tensor. - This function follow the api from :any:`numpy.isnan` + This function follows the api from :any:`numpy.isnan` See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html """ @@ -323,7 +323,7 @@ class Backend(): r""" Tests element-wise for positive or negative infinity and returns result as a boolean tensor. - This function follow the api from :any:`numpy.isinf` + This function follows the api from :any:`numpy.isinf` See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html """ @@ -333,7 +333,7 @@ class Backend(): r""" Evaluates the Einstein summation convention on the operands. - This function follow the api from :any:`numpy.einsum` + This function follows the api from :any:`numpy.einsum` See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html """ @@ -343,7 +343,7 @@ class Backend(): r""" Returns a sorted copy of a tensor. - This function follow the api from :any:`numpy.sort` + This function follows the api from :any:`numpy.sort` See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html """ @@ -353,7 +353,7 @@ class Backend(): r""" Returns the indices that would sort a tensor. - This function follow the api from :any:`numpy.argsort` + This function follows the api from :any:`numpy.argsort` See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html """ @@ -363,7 +363,7 @@ class Backend(): r""" Finds indices where elements should be inserted to maintain order in given tensor. - This function follow the api from :any:`numpy.searchsorted` + This function follows the api from :any:`numpy.searchsorted` See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html """ @@ -373,7 +373,7 @@ class Backend(): r""" Reverses the order of elements in a tensor along given dimensions. - This function follow the api from :any:`numpy.flip` + This function follows the api from :any:`numpy.flip` See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html """ @@ -383,7 +383,7 @@ class Backend(): """ Limits the values in a tensor. - This function follow the api from :any:`numpy.clip` + This function follows the api from :any:`numpy.clip` See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html """ @@ -393,7 +393,7 @@ class Backend(): r""" Repeats elements of a tensor. - This function follow the api from :any:`numpy.repeat` + This function follows the api from :any:`numpy.repeat` See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html """ @@ -403,7 +403,7 @@ class Backend(): r""" Gathers elements of a tensor along given dimensions. - This function follow the api from :any:`numpy.take_along_axis` + This function follows the api from :any:`numpy.take_along_axis` See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html """ @@ -413,7 +413,7 @@ class Backend(): r""" Joins a sequence of tensors along an existing dimension. - This function follow the api from :any:`numpy.concatenate` + This function follows the api from :any:`numpy.concatenate` See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html """ @@ -423,7 +423,7 @@ class Backend(): r""" Pads a tensor. - This function follow the api from :any:`numpy.pad` + This function follows the api from :any:`numpy.pad` See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html """ @@ -433,7 +433,7 @@ class Backend(): r""" Returns the indices of the maximum values of a tensor along given dimensions. - This function follow the api from :any:`numpy.argmax` + This function follows the api from :any:`numpy.argmax` See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html """ @@ -443,7 +443,7 @@ class Backend(): r""" Computes the arithmetic mean of a tensor along given dimensions. - This function follow the api from :any:`numpy.mean` + This function follows the api from :any:`numpy.mean` See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html """ @@ -453,7 +453,7 @@ class Backend(): r""" Computes the standard deviation of a tensor along given dimensions. - This function follow the api from :any:`numpy.std` + This function follows the api from :any:`numpy.std` See: https://numpy.org/doc/stable/reference/generated/numpy.std.html """ @@ -463,7 +463,7 @@ class Backend(): r""" Returns a specified number of evenly spaced values over a given interval. - This function follow the api from :any:`numpy.linspace` + This function follows the api from :any:`numpy.linspace` See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html """ @@ -473,7 +473,7 @@ class Backend(): r""" Returns coordinate matrices from coordinate vectors (Numpy convention). - This function follow the api from :any:`numpy.meshgrid` + This function follows the api from :any:`numpy.meshgrid` See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html """ @@ -483,7 +483,7 @@ class Backend(): r""" Extracts or constructs a diagonal tensor. - This function follow the api from :any:`numpy.diag` + This function follows the api from :any:`numpy.diag` See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html """ @@ -493,7 +493,7 @@ class Backend(): r""" Finds unique elements of given tensor. - This function follow the api from :any:`numpy.unique` + This function follows the api from :any:`numpy.unique` See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html """ @@ -503,7 +503,7 @@ class Backend(): r""" Computes the log of the sum of exponentials of input elements. - This function follow the api from :any:`scipy.special.logsumexp` + This function follows the api from :any:`scipy.special.logsumexp` See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html """ @@ -513,12 +513,32 @@ class Backend(): r""" Joins a sequence of tensors along a new dimension. - This function follow the api from :any:`numpy.stack` + This function follows the api from :any:`numpy.stack` See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html """ raise NotImplementedError() + def outer(self, a, b): + r""" + Computes the outer product between two vectors. + + This function follows the api from :any:`numpy.outer` + + See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html + """ + raise NotImplementedError() + + def reshape(self, a, shape): + r""" + Gives a new shape to a tensor without changing its data. + + This function follows the api from :any:`numpy.reshape` + + See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -644,6 +664,9 @@ class NumpyBackend(Backend): def flip(self, a, axis=None): return np.flip(a, axis) + def outer(self, a, b): + return np.outer(a, b) + def clip(self, a, a_min, a_max): return np.clip(a, a_min, a_max) @@ -686,6 +709,9 @@ class NumpyBackend(Backend): def stack(self, arrays, axis=0): return np.stack(arrays, axis) + def reshape(self, a, shape): + return np.reshape(a, shape) + class JaxBackend(Backend): """ @@ -815,6 +841,9 @@ class JaxBackend(Backend): def flip(self, a, axis=None): return jnp.flip(a, axis) + def outer(self, a, b): + return jnp.outer(a, b) + def clip(self, a, a_min, a_max): return jnp.clip(a, a_min, a_max) @@ -857,6 +886,9 @@ class JaxBackend(Backend): def stack(self, arrays, axis=0): return jnp.stack(arrays, axis) + def reshape(self, a, shape): + return jnp.reshape(a, shape) + class TorchBackend(Backend): """ @@ -1035,6 +1067,9 @@ class TorchBackend(Backend): else: return torch.flip(a, dims=axis) + def outer(self, a, b): + return torch.outer(a, b) + def clip(self, a, a_min, a_max): return torch.clamp(a, a_min, a_max) @@ -1091,3 +1126,6 @@ class TorchBackend(Backend): def stack(self, arrays, axis=0): return torch.stack(arrays, dim=axis) + + def reshape(self, a, shape): + return torch.reshape(a, shape) |