summaryrefslogtreecommitdiff
path: root/ot/backend.py
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-11-02 13:42:02 +0100
committerGitHub <noreply@github.com>2021-11-02 13:42:02 +0100
commita335324d008e8982be61d7ace937815a2bfa98f9 (patch)
tree83c7f637597f10f6f3d20b15532e53fc65b51f22 /ot/backend.py
parent0cb2b2efe901ed74c614046d250518769f870313 (diff)
[MRG] Backend for gromov (#294)
* bregman: small correction * gromov backend first draft * Removing decorators * Reworked casting method * Bug solve * Removing casting * Bug solve * toarray renamed todense ; expand_dims removed * Warning (jax not supporting sparse matrix) moved * Mistake corrected * test backend * Sparsity test for older versions of pytorch * Trying pytorch/1.10 * Attempt to correct torch sparse bug * Backend version of gromov tests * Random state introduced for remaining gromov functions * review changes * code coverage * Docs (first draft, to be continued) * Gromov docs * Prettified docs * mistake corrected in the docs * little change Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/backend.py')
-rw-r--r--ot/backend.py214
1 files changed, 212 insertions, 2 deletions
diff --git a/ot/backend.py b/ot/backend.py
index 876b96a..358297c 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -26,6 +26,7 @@ Examples
import numpy as np
import scipy.special as scipy
+from scipy.sparse import issparse, coo_matrix, csr_matrix
try:
import torch
@@ -539,6 +540,86 @@ class Backend():
"""
raise NotImplementedError()
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ r"""
+ Creates a sparse tensor in COOrdinate format.
+
+ This function follows the api from :any:`scipy.sparse.coo_matrix`
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html
+ """
+ raise NotImplementedError()
+
+ def issparse(self, a):
+ r"""
+ Checks whether or not the input tensor is a sparse tensor.
+
+ This function follows the api from :any:`scipy.sparse.issparse`
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html
+ """
+ raise NotImplementedError()
+
+ def tocsr(self, a):
+ r"""
+ Converts this matrix to Compressed Sparse Row format.
+
+ This function follows the api from :any:`scipy.sparse.coo_matrix.tocsr`
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html
+ """
+ raise NotImplementedError()
+
+ def eliminate_zeros(self, a, threshold=0.):
+ r"""
+ Removes entries smaller than the given threshold from the sparse tensor.
+
+ This function follows the api from :any:`scipy.sparse.csr_matrix.eliminate_zeros`
+
+ See: https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html
+ """
+ raise NotImplementedError()
+
+ def todense(self, a):
+ r"""
+ Converts a sparse tensor to a dense tensor.
+
+ This function follows the api from :any:`scipy.sparse.csr_matrix.toarray`
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html
+ """
+ raise NotImplementedError()
+
+ def where(self, condition, x, y):
+ r"""
+ Returns elements chosen from x or y depending on condition.
+
+ This function follows the api from :any:`numpy.where`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.where.html
+ """
+ raise NotImplementedError()
+
+ def copy(self, a):
+ r"""
+ Returns a copy of the given tensor.
+
+ This function follows the api from :any:`numpy.copy`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html
+ """
+ raise NotImplementedError()
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ r"""
+ Returns True if two arrays are element-wise equal within a tolerance.
+
+ This function follows the api from :any:`numpy.allclose`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html
+ """
+ raise NotImplementedError()
+
class NumpyBackend(Backend):
"""
@@ -712,6 +793,46 @@ class NumpyBackend(Backend):
def reshape(self, a, shape):
return np.reshape(a, shape)
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ if type_as is None:
+ return coo_matrix((data, (rows, cols)), shape=shape)
+ else:
+ return coo_matrix((data, (rows, cols)), shape=shape, dtype=type_as.dtype)
+
+ def issparse(self, a):
+ return issparse(a)
+
+ def tocsr(self, a):
+ if self.issparse(a):
+ return a.tocsr()
+ else:
+ return csr_matrix(a)
+
+ def eliminate_zeros(self, a, threshold=0.):
+ if threshold > 0:
+ if self.issparse(a):
+ a.data[self.abs(a.data) <= threshold] = 0
+ else:
+ a[self.abs(a) <= threshold] = 0
+ if self.issparse(a):
+ a.eliminate_zeros()
+ return a
+
+ def todense(self, a):
+ if self.issparse(a):
+ return a.toarray()
+ else:
+ return a
+
+ def where(self, condition, x, y):
+ return np.where(condition, x, y)
+
+ def copy(self, a):
+ return a.copy()
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
class JaxBackend(Backend):
"""
@@ -889,6 +1010,48 @@ class JaxBackend(Backend):
def reshape(self, a, shape):
return jnp.reshape(a, shape)
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ # Currently, JAX does not support sparse matrices
+ data = self.to_numpy(data)
+ rows = self.to_numpy(rows)
+ cols = self.to_numpy(cols)
+ nx = NumpyBackend()
+ coo_matrix = nx.coo_matrix(data, rows, cols, shape=shape, type_as=type_as)
+ matrix = nx.todense(coo_matrix)
+ return self.from_numpy(matrix)
+
+ def issparse(self, a):
+ # Currently, JAX does not support sparse matrices
+ return False
+
+ def tocsr(self, a):
+ # Currently, JAX does not support sparse matrices
+ return a
+
+ def eliminate_zeros(self, a, threshold=0.):
+ # Currently, JAX does not support sparse matrices
+ if threshold > 0:
+ return self.where(
+ self.abs(a) <= threshold,
+ self.zeros((1,), type_as=a),
+ a
+ )
+ return a
+
+ def todense(self, a):
+ # Currently, JAX does not support sparse matrices
+ return a
+
+ def where(self, condition, x, y):
+ return jnp.where(condition, x, y)
+
+ def copy(self, a):
+ # No need to copy, JAX arrays are immutable
+ return a
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
class TorchBackend(Backend):
"""
@@ -999,7 +1162,7 @@ class TorchBackend(Backend):
a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
if isinstance(b, int) or isinstance(b, float):
b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
- if torch.__version__ >= '1.7.0':
+ if hasattr(torch, "maximum"):
return torch.maximum(a, b)
else:
return torch.max(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]
@@ -1009,7 +1172,7 @@ class TorchBackend(Backend):
a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
if isinstance(b, int) or isinstance(b, float):
b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
- if torch.__version__ >= '1.7.0':
+ if hasattr(torch, "minimum"):
return torch.minimum(a, b)
else:
return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]
@@ -1129,3 +1292,50 @@ class TorchBackend(Backend):
def reshape(self, a, shape):
return torch.reshape(a, shape)
+
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ if type_as is None:
+ return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape)
+ else:
+ return torch.sparse_coo_tensor(
+ torch.stack([rows, cols]), data, size=shape,
+ dtype=type_as.dtype, device=type_as.device
+ )
+
+ def issparse(self, a):
+ return getattr(a, "is_sparse", False) or getattr(a, "is_sparse_csr", False)
+
+ def tocsr(self, a):
+ # Versions older than 1.9 do not support CSR tensors. PyTorch 1.9 and 1.10 offer a very limited support
+ return self.todense(a)
+
+ def eliminate_zeros(self, a, threshold=0.):
+ if self.issparse(a):
+ if threshold > 0:
+ mask = self.abs(a) <= threshold
+ mask = ~mask
+ mask = mask.nonzero()
+ else:
+ mask = a._values().nonzero()
+ nv = a._values().index_select(0, mask.view(-1))
+ ni = a._indices().index_select(1, mask.view(-1))
+ return self.coo_matrix(nv, ni[0], ni[1], shape=a.shape, type_as=a)
+ else:
+ if threshold > 0:
+ a[self.abs(a) <= threshold] = 0
+ return a
+
+ def todense(self, a):
+ if self.issparse(a):
+ return a.to_dense()
+ else:
+ return a
+
+ def where(self, condition, x, y):
+ return torch.where(condition, x, y)
+
+ def copy(self, a):
+ return torch.clone(a)
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)