Diffstat (limited to 'ot/backend.py')
-rw-r--r--  ot/backend.py  304
1 file changed, 273 insertions(+), 31 deletions(-)
diff --git a/ot/backend.py b/ot/backend.py
index 6e0bc3d..361ffba 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -87,7 +87,9 @@ Performance
# License: MIT License
import numpy as np
-import scipy.special as scipy
+import scipy
+import scipy.linalg
+import scipy.special as special
from scipy.sparse import issparse, coo_matrix, csr_matrix
import warnings
import time
@@ -102,7 +104,7 @@ except ImportError:
try:
import jax
import jax.numpy as jnp
- import jax.scipy.special as jscipy
+ import jax.scipy.special as jspecial
from jax.lib import xla_bridge
jax_type = jax.numpy.ndarray
except ImportError:
@@ -202,13 +204,29 @@ class Backend():
def __str__(self):
return self.__name__
- # convert to numpy
- def to_numpy(self, a):
+ # convert batch of tensors to numpy
+ def to_numpy(self, *arrays):
+ """Returns the numpy version of tensors"""
+ if len(arrays) == 1:
+ return self._to_numpy(arrays[0])
+ else:
+ return [self._to_numpy(array) for array in arrays]
+
+ # convert a tensor to numpy
+ def _to_numpy(self, a):
"""Returns the numpy version of a tensor"""
raise NotImplementedError()
- # convert from numpy
- def from_numpy(self, a, type_as=None):
+ # convert batch of arrays from numpy
+ def from_numpy(self, *arrays, type_as=None):
+ """Creates tensors cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)"""
+ if len(arrays) == 1:
+ return self._from_numpy(arrays[0], type_as=type_as)
+ else:
+ return [self._from_numpy(array, type_as=type_as) for array in arrays]
+
+ # convert an array from numpy
+ def _from_numpy(self, a, type_as=None):
"""Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)"""
raise NotImplementedError()
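The batched wrappers keep the single-array behaviour while letting callers convert several arrays in one call. A minimal sketch of the new calling convention, using the NumpyBackend defined below (the same pattern holds for every backend):

    import numpy as np
    from ot.backend import NumpyBackend

    nx = NumpyBackend()
    a, b = np.ones(3), np.zeros((2, 2))

    # One input: one tensor back, exactly as before
    ta = nx.from_numpy(a)

    # Several inputs: a list of tensors, one per input
    ta, tb = nx.from_numpy(a, b)

    # Batching works the same way on the round trip
    a2, b2 = nx.to_numpy(ta, tb)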
@@ -536,6 +554,16 @@ class Backend():
"""
raise NotImplementedError()
+ def argmin(self, a, axis=None):
+ r"""
+ Returns the indices of the minimum values of a tensor along given dimensions.
+
+ This function follows the api from :any:`numpy.argmin`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.argmin.html
+ """
+ raise NotImplementedError()
+
def mean(self, a, axis=None):
r"""
Computes the arithmetic mean of a tensor along given dimensions.
@@ -786,6 +814,72 @@ class Backend():
"""
raise NotImplementedError()
+ def solve(self, a, b):
+ r"""
+ Solves a linear matrix equation, or system of linear scalar equations.
+
+ This function follows the api from :any:`numpy.linalg.solve`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html
+ """
+ raise NotImplementedError()
+
+ def trace(self, a):
+ r"""
+ Returns the sum along diagonals of the array.
+
+ This function follows the api from :any:`numpy.trace`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.trace.html
+ """
+ raise NotImplementedError()
+
+ def inv(self, a):
+ r"""
+ Computes the inverse of a matrix.
+
+ This function follows the api from :any:`scipy.linalg.inv`.
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.inv.html
+ """
+ raise NotImplementedError()
+
+ def sqrtm(self, a):
+ r"""
+ Computes the matrix square root. Requires input to be positive definite.
+
+ This function follows the api from :any:`scipy.linalg.sqrtm`.
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.sqrtm.html
+ """
+ raise NotImplementedError()
+
+ def isfinite(self, a):
+ r"""
+ Tests element-wise for finiteness (not infinity and not Not a Number).
+
+ This function follows the api from :any:`numpy.isfinite`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html
+ """
+ raise NotImplementedError()
+
+ def array_equal(self, a, b):
+ r"""
+ True if two arrays have the same shape and elements, False otherwise.
+
+ This function follows the api from :any:`numpy.array_equal`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.array_equal.html
+ """
+ raise NotImplementedError()
+
+ def is_floating_point(self, a):
+ r"""
+ Returns whether or not the input tensor consists of floats.
+ """
+ raise NotImplementedError()
+
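Taken together, these primitives let downstream solvers stay backend-agnostic even for linear-algebra heavy code. A hedged sketch of the intended pattern, computing the squared Bures distance between two SPD matrices; bures2 is a hypothetical helper written for illustration, and get_backend is POT's existing dispatcher:

    import numpy as np
    from ot.backend import get_backend

    def bures2(C1, C2):
        # Hypothetical helper: squared Bures distance between SPD matrices,
        # written only against the abstract Backend API
        nx = get_backend(C1, C2)
        S1 = nx.sqrtm(C1)
        M = nx.sqrtm(nx.dot(nx.dot(S1, C2), S1))
        return nx.trace(C1) + nx.trace(C2) - 2 * nx.trace(M)

    print(bures2(np.eye(2), 4 * np.eye(2)))  # tr(I) + tr(4I) - 2 tr(2I) = 2.0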
class NumpyBackend(Backend):
"""
@@ -802,10 +896,10 @@ class NumpyBackend(Backend):
rng_ = np.random.RandomState()
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return a
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
if type_as is None:
return a
elif isinstance(a, float):
@@ -936,6 +1030,9 @@ class NumpyBackend(Backend):
def argmax(self, a, axis=None):
return np.argmax(a, axis=axis)
+ def argmin(self, a, axis=None):
+ return np.argmin(a, axis=axis)
+
def mean(self, a, axis=None):
return np.mean(a, axis=axis)
@@ -955,7 +1052,7 @@ class NumpyBackend(Backend):
return np.unique(a)
def logsumexp(self, a, axis=None):
- return scipy.logsumexp(a, axis=axis)
+ return special.logsumexp(a, axis=axis)
def stack(self, arrays, axis=0):
return np.stack(arrays, axis)
@@ -1004,8 +1101,11 @@ class NumpyBackend(Backend):
else:
return a
- def where(self, condition, x, y):
- return np.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return np.where(condition)
+ else:
+ return np.where(condition, x, y)
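The single-argument form matches numpy, where np.where(condition) behaves like np.nonzero(condition) and returns a tuple of index arrays, while the three-argument form selects elementwise. For instance:

    import numpy as np
    from ot.backend import NumpyBackend

    nx = NumpyBackend()
    a = np.array([-1.0, 2.0, -3.0])

    print(nx.where(a > 0, a, 0.0))  # elementwise select: [0. 2. 0.]

    idx, = nx.where(a > 0)          # indices where the condition holds
    print(idx)                      # [1]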
def copy(self, a):
return a.copy()
@@ -1046,6 +1146,27 @@ class NumpyBackend(Backend):
results[key] = (t1 - t0) / n_runs
return results
+ def solve(self, a, b):
+ return np.linalg.solve(a, b)
+
+ def trace(self, a):
+ return np.trace(a)
+
+ def inv(self, a):
+ return scipy.linalg.inv(a)
+
+ def sqrtm(self, a):
+ return scipy.linalg.sqrtm(a)
+
+ def isfinite(self, a):
+ return np.isfinite(a)
+
+ def array_equal(self, a, b):
+ return np.array_equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.kind == "f"
+
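A quick sanity check of the NumpyBackend additions. sqrtm delegates to scipy.linalg.sqrtm, which returns the principal square root, so squaring it recovers a symmetric positive definite input:

    import numpy as np
    from ot.backend import NumpyBackend

    nx = NumpyBackend()
    A = np.array([[2.0, 1.0], [1.0, 2.0]])  # symmetric positive definite

    B = nx.sqrtm(A)
    assert np.allclose(B @ B, A)

    x = nx.solve(A, np.ones(2))             # solve A x = [1, 1]
    assert np.allclose(A @ x, np.ones(2))
    assert nx.isfinite(B).all() and nx.is_floating_point(B)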
class JaxBackend(Backend):
"""
@@ -1075,13 +1196,15 @@ class JaxBackend(Backend):
jax.device_put(jnp.array(1, dtype=jnp.float64), d)
]
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return np.array(a)
def _change_device(self, a, type_as):
return jax.device_put(a, type_as.device_buffer.device())
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
+ if isinstance(a, float):
+ a = np.array(a)
if type_as is None:
return jnp.array(a)
else:
@@ -1216,6 +1339,9 @@ class JaxBackend(Backend):
def argmax(self, a, axis=None):
return jnp.argmax(a, axis=axis)
+ def argmin(self, a, axis=None):
+ return jnp.argmin(a, axis=axis)
+
def mean(self, a, axis=None):
return jnp.mean(a, axis=axis)
@@ -1235,7 +1361,7 @@ class JaxBackend(Backend):
return jnp.unique(a)
def logsumexp(self, a, axis=None):
- return jscipy.logsumexp(a, axis=axis)
+ return jspecial.logsumexp(a, axis=axis)
def stack(self, arrays, axis=0):
return jnp.stack(arrays, axis)
@@ -1293,8 +1419,11 @@ class JaxBackend(Backend):
# Currently, JAX does not support sparse matrices
return a
- def where(self, condition, x, y):
- return jnp.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return jnp.where(condition)
+ else:
+ return jnp.where(condition, x, y)
def copy(self, a):
# No need to copy, JAX arrays are immutable
@@ -1339,6 +1468,28 @@ class JaxBackend(Backend):
results[key] = (t1 - t0) / n_runs
return results
+ def solve(self, a, b):
+ return jnp.linalg.solve(a, b)
+
+ def trace(self, a):
+ return jnp.trace(a)
+
+ def inv(self, a):
+ return jnp.linalg.inv(a)
+
+ def sqrtm(self, a):
+ L, V = jnp.linalg.eigh(a)
+ return (V * jnp.sqrt(L)[None, :]) @ V.T
+
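Unlike the numpy backend, the JAX square root is built from a symmetric eigendecomposition, V diag(sqrt(L)) V^T, so it is only correct for the symmetric positive semi-definite inputs the base-class docstring requires. A small check of the construction, assuming jax is installed:

    import jax.numpy as jnp

    A = jnp.array([[2.0, 1.0], [1.0, 2.0]])  # symmetric PSD

    # Same construction as JaxBackend.sqrtm
    L, V = jnp.linalg.eigh(A)
    B = (V * jnp.sqrt(L)[None, :]) @ V.T

    print(jnp.allclose(B @ B, A, atol=1e-5))  # True

For non-symmetric input this construction silently returns a wrong result, so callers must respect the precondition.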
+ def isfinite(self, a):
+ return jnp.isfinite(a)
+
+ def array_equal(self, a, b):
+ return jnp.array_equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.kind == "f"
+
class TorchBackend(Backend):
"""
@@ -1384,10 +1535,10 @@ class TorchBackend(Backend):
self.ValFunction = ValFunction
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return a.cpu().detach().numpy()
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
if isinstance(a, float):
a = np.array(a)
if type_as is None:
@@ -1564,6 +1715,9 @@ class TorchBackend(Backend):
def argmax(self, a, axis=None):
return torch.argmax(a, dim=axis)
+ def argmin(self, a, axis=None):
+ return torch.argmin(a, dim=axis)
+
def mean(self, a, axis=None):
if axis is not None:
return torch.mean(a, dim=axis)
@@ -1580,8 +1734,11 @@ class TorchBackend(Backend):
return torch.linspace(start, stop, num, dtype=torch.float64)
def meshgrid(self, a, b):
- X, Y = torch.meshgrid(a, b)
- return X.T, Y.T
+ try:
+ return torch.meshgrid(a, b, indexing="xy")
+ except TypeError:
+ X, Y = torch.meshgrid(a, b)
+ return X.T, Y.T
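PyTorch 1.10 added the indexing keyword to torch.meshgrid and warns when it is omitted; indexing="xy" reproduces numpy's default layout, while older releases raise TypeError on the keyword, hence the fallback that transposes the "ij" output. A sketch of the equivalence, assuming torch >= 1.10:

    import numpy as np
    import torch

    X, Y = torch.meshgrid(torch.arange(3), torch.arange(2), indexing="xy")
    Xn, Yn = np.meshgrid(np.arange(3), np.arange(2))

    assert (X.numpy() == Xn).all() and (Y.numpy() == Yn).all()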
def diag(self, a, k=0):
return torch.diag(a, diagonal=k)
@@ -1659,8 +1816,11 @@ class TorchBackend(Backend):
else:
return a
- def where(self, condition, x, y):
- return torch.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return torch.where(condition)
+ else:
+ return torch.where(condition, x, y)
def copy(self, a):
return torch.clone(a)
@@ -1718,6 +1878,28 @@ class TorchBackend(Backend):
torch.cuda.empty_cache()
return results
+ def solve(self, a, b):
+ return torch.linalg.solve(a, b)
+
+ def trace(self, a):
+ return torch.trace(a)
+
+ def inv(self, a):
+ return torch.linalg.inv(a)
+
+ def sqrtm(self, a):
+ L, V = torch.linalg.eigh(a)
+ return (V * torch.sqrt(L)[None, :]) @ V.T
+
+ def isfinite(self, a):
+ return torch.isfinite(a)
+
+ def array_equal(self, a, b):
+ return torch.equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.is_floating_point
+
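torch has no array_equal, so the closest match, torch.equal, is used; note it also requires matching dtype and device, a slightly stricter contract than numpy's (a hedged observation, not something this diff enforces). For example, assuming torch is installed:

    import torch
    from ot.backend import TorchBackend

    nx = TorchBackend()
    a = torch.tensor([1.0, 2.0])

    assert nx.array_equal(a, a.clone())
    assert not nx.array_equal(a, torch.tensor([1.0, 3.0]))
    assert nx.is_floating_point(a)
    assert not nx.is_floating_point(torch.tensor([1, 2]))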
class CupyBackend(Backend): # pragma: no cover
"""
@@ -1741,10 +1923,12 @@ class CupyBackend(Backend): # pragma: no cover
cp.array(1, dtype=cp.float64)
]
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return cp.asnumpy(a)
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
+ if isinstance(a, float):
+ a = np.array(a)
if type_as is None:
return cp.asarray(a)
else:
@@ -1884,6 +2068,9 @@ class CupyBackend(Backend): # pragma: no cover
def argmax(self, a, axis=None):
return cp.argmax(a, axis=axis)
+ def argmin(self, a, axis=None):
+ return cp.argmin(a, axis=axis)
+
def mean(self, a, axis=None):
return cp.mean(a, axis=axis)
@@ -1982,8 +2169,11 @@ class CupyBackend(Backend): # pragma: no cover
else:
return a
- def where(self, condition, x, y):
- return cp.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return cp.where(condition)
+ else:
+ return cp.where(condition, x, y)
def copy(self, a):
return a.copy()
@@ -2035,6 +2225,28 @@ class CupyBackend(Backend): # pragma: no cover
pinned_mempool.free_all_blocks()
return results
+ def solve(self, a, b):
+ return cp.linalg.solve(a, b)
+
+ def trace(self, a):
+ return cp.trace(a)
+
+ def inv(self, a):
+ return cp.linalg.inv(a)
+
+ def sqrtm(self, a):
+ L, V = cp.linalg.eigh(a)
+ return (V * self.sqrt(L)[None, :]) @ V.T
+
+ def isfinite(self, a):
+ return cp.isfinite(a)
+
+ def array_equal(self, a, b):
+ return cp.array_equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.kind == "f"
+
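CuPy mirrors the eigh-based construction used for JAX and Torch, going through the backend's own sqrt wrapper (self.sqrt) rather than calling cp.sqrt directly; the result is the same. A minimal check, assuming cupy and a CUDA device are available:

    import cupy as cp

    A = cp.array([[2.0, 1.0], [1.0, 2.0]])  # symmetric PSD
    L, V = cp.linalg.eigh(A)
    B = (V * cp.sqrt(L)[None, :]) @ V.T
    assert cp.allclose(B @ B, A)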
class TensorflowBackend(Backend):
@@ -2060,13 +2272,16 @@ class TensorflowBackend(Backend):
"To use TensorflowBackend, you need to activate the tensorflow "
"numpy API. You can activate it by running: \n"
"from tensorflow.python.ops.numpy_ops import np_config\n"
- "np_config.enable_numpy_behavior()"
+ "np_config.enable_numpy_behavior()",
+ stacklevel=2
)
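The warning now points at the caller (stacklevel=2) and already spells out the fix; for completeness, the activation it asks for must run once before any TensorflowBackend use:

    from tensorflow.python.ops.numpy_ops import np_config
    np_config.enable_numpy_behavior()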
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return a.numpy()
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
+ if isinstance(a, float):
+ a = np.array(a)
if not isinstance(a, self.__type__):
if type_as is None:
return tf.convert_to_tensor(a)
@@ -2208,6 +2423,9 @@ class TensorflowBackend(Backend):
def argmax(self, a, axis=None):
return tnp.argmax(a, axis=axis)
+ def argmin(self, a, axis=None):
+ return tnp.argmin(a, axis=axis)
+
def mean(self, a, axis=None):
return tnp.mean(a, axis=axis)
@@ -2309,8 +2527,11 @@ class TensorflowBackend(Backend):
else:
return a
- def where(self, condition, x, y):
- return tnp.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return tnp.where(condition)
+ else:
+ return tnp.where(condition, x, y)
def copy(self, a):
return tf.identity(a)
@@ -2364,3 +2585,24 @@ class TensorflowBackend(Backend):
results[key] = (t1 - t0) / n_runs
return results
+
+ def solve(self, a, b):
+ return tf.linalg.solve(a, b)
+
+ def trace(self, a):
+ return tf.linalg.trace(a)
+
+ def inv(self, a):
+ return tf.linalg.inv(a)
+
+ def sqrtm(self, a):
+ return tf.linalg.sqrtm(a)
+
+ def isfinite(self, a):
+ return tnp.isfinite(a)
+
+ def array_equal(self, a, b):
+ return tnp.array_equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.is_floating
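Together with argmin and the batched conversions above, these additions keep the unified API drop-in across frameworks. A closing sketch; closest_row is a hypothetical helper for illustration, dispatched through POT's get_backend:

    import numpy as np
    from ot.backend import get_backend

    def closest_row(M, x):
        # Hypothetical helper: index of the row of M nearest to x,
        # computed on whichever backend M and x live on
        nx = get_backend(M, x)
        d = nx.sum((M - x[None, :]) ** 2, axis=1)
        return nx.argmin(d)

    M = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
    print(closest_row(M, np.array([0.9, 1.2])))  # 1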