summaryrefslogtreecommitdiff
path: root/ot/backend.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/backend.py')
-rw-r--r--ot/backend.py876
1 files changed, 870 insertions, 6 deletions
diff --git a/ot/backend.py b/ot/backend.py
index a044f84..58b652b 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -3,7 +3,7 @@
Multi-lib backend for POT
The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch,
-or Jax, POT code should work nonetheless.
+Jax, Cupy, or Tensorflow, POT code should work nonetheless.
To achieve that, POT provides backend classes which implements functions in their respective backend
imitating Numpy API. As a convention, we use nx instead of np to refer to the backend.
@@ -17,6 +17,68 @@ Examples
... nx = get_backend(a, b) # infer the backend from the arguments
... c = nx.dot(a, b) # now use the backend to do any calculation
... return c
+
+.. warning::
+ The Tensorflow backend only works when Tensorflow's Numpy behavior is enabled. To activate it, please run the following:
+
+ .. code-block::
+
+ from tensorflow.python.ops.numpy_ops import np_config
+ np_config.enable_numpy_behavior()
+
+Performance
+-----------
+
+- CPU: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz
+- GPU: Tesla V100-SXM2-32GB
+- Date of the benchmark: December 8th, 2021
+- Commit of benchmark: PR #316, https://github.com/PythonOT/POT/pull/316
+
+.. raw:: html
+
+ <style>
+ #perftable {
+ width: 100%;
+ margin-bottom: 1em;
+ }
+
+ #perftable table{
+ border-collapse: collapse;
+ table-layout: fixed;
+ width: 100%;
+ }
+
+ #perftable th, #perftable td {
+ border: 1px solid #ddd;
+ padding: 8px;
+ font-size: smaller;
+ }
+ </style>
+
+ <div id="perftable">
+ <table>
+ <tr><th align="center" colspan="8">Sinkhorn Knopp - Averaged on 100 runs</th></tr>
+ <tr><th align="center">Bitsize</th><th align="center" colspan="7">32 bits</th></tr>
+ <tr><th align="center">Device</th><th align="center" colspan="3">CPU</th><th align="center" colspan="4">GPU</th></tr>
+ <tr><th align="center">Sample size</th><th align="center">Numpy</th><th align="center">Pytorch</th><th align="center">Tensorflow</th><th align="center">Cupy</th><th align="center">Jax</th><th align="center">Pytorch</th><th align="center">Tensorflow</th></tr>
+ <tr><td align="center">50</td><td align="center">0.0008</td><td align="center">0.0022</td><td align="center">0.0151</td><td align="center">0.0095</td><td align="center">0.0193</td><td align="center">0.0051</td><td align="center">0.0293</td></tr>
+ <tr><td align="center">100</td><td align="center">0.0005</td><td align="center">0.0013</td><td align="center">0.0097</td><td align="center">0.0057</td><td align="center">0.0115</td><td align="center">0.0029</td><td align="center">0.0173</td></tr>
+ <tr><td align="center">500</td><td align="center">0.0009</td><td align="center">0.0016</td><td align="center">0.0110</td><td align="center">0.0058</td><td align="center">0.0115</td><td align="center">0.0029</td><td align="center">0.0166</td></tr>
+ <tr><td align="center">1000</td><td align="center">0.0021</td><td align="center">0.0021</td><td align="center">0.0145</td><td align="center">0.0056</td><td align="center">0.0118</td><td align="center">0.0029</td><td align="center">0.0168</td></tr>
+ <tr><td align="center">2000</td><td align="center">0.0069</td><td align="center">0.0043</td><td align="center">0.0278</td><td align="center">0.0059</td><td align="center">0.0118</td><td align="center">0.0030</td><td align="center">0.0165</td></tr>
+ <tr><td align="center">5000</td><td align="center">0.0707</td><td align="center">0.0314</td><td align="center">0.1395</td><td align="center">0.0074</td><td align="center">0.0125</td><td align="center">0.0035</td><td align="center">0.0198</td></tr>
+ <tr><td colspan="8">&nbsp;</td></tr>
+ <tr><th align="center">Bitsize</th><th align="center" colspan="7">64 bits</th></tr>
+ <tr><th align="center">Device</th><th align="center" colspan="3">CPU</th><th align="center" colspan="4">GPU</th></tr>
+ <tr><th align="center">Sample size</th><th align="center">Numpy</th><th align="center">Pytorch</th><th align="center">Tensorflow</th><th align="center">Cupy</th><th align="center">Jax</th><th align="center">Pytorch</th><th align="center">Tensorflow</th></tr>
+ <tr><td align="center">50</td><td align="center">0.0008</td><td align="center">0.0020</td><td align="center">0.0154</td><td align="center">0.0093</td><td align="center">0.0191</td><td align="center">0.0051</td><td align="center">0.0328</td></tr>
+ <tr><td align="center">100</td><td align="center">0.0005</td><td align="center">0.0013</td><td align="center">0.0094</td><td align="center">0.0056</td><td align="center">0.0114</td><td align="center">0.0029</td><td align="center">0.0169</td></tr>
+ <tr><td align="center">500</td><td align="center">0.0013</td><td align="center">0.0017</td><td align="center">0.0120</td><td align="center">0.0059</td><td align="center">0.0116</td><td align="center">0.0029</td><td align="center">0.0168</td></tr>
+ <tr><td align="center">1000</td><td align="center">0.0034</td><td align="center">0.0027</td><td align="center">0.0177</td><td align="center">0.0058</td><td align="center">0.0118</td><td align="center">0.0029</td><td align="center">0.0167</td></tr>
+ <tr><td align="center">2000</td><td align="center">0.0146</td><td align="center">0.0075</td><td align="center">0.0436</td><td align="center">0.0059</td><td align="center">0.0120</td><td align="center">0.0029</td><td align="center">0.0165</td></tr>
+ <tr><td align="center">5000</td><td align="center">0.1467</td><td align="center">0.0568</td><td align="center">0.2468</td><td align="center">0.0077</td><td align="center">0.0146</td><td align="center">0.0045</td><td align="center">0.0204</td></tr>
+ </table>
+ </div>
"""
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
@@ -27,6 +89,8 @@ Examples
import numpy as np
import scipy.special as scipy
from scipy.sparse import issparse, coo_matrix, csr_matrix
+import warnings
+import time
try:
import torch
@@ -39,11 +103,29 @@ try:
import jax
import jax.numpy as jnp
import jax.scipy.special as jscipy
+ from jax.lib import xla_bridge
jax_type = jax.numpy.ndarray
except ImportError:
jax = False
jax_type = float
+try:
+ import cupy as cp
+ import cupyx
+ cp_type = cp.ndarray
+except ImportError:
+ cp = False
+ cp_type = float
+
+try:
+ import tensorflow as tf
+ import tensorflow.experimental.numpy as tnp
+ tf_type = tf.Tensor
+except ImportError:
+ tf = False
+ tf_type = float
+
+
str_type_error = "All array should be from the same type/backend. Current types are : {}"
@@ -57,6 +139,12 @@ def get_backend_list():
if jax:
lst.append(JaxBackend())
+ if cp: # pragma: no cover
+ lst.append(CupyBackend())
+
+ if tf:
+ lst.append(TensorflowBackend())
+
return lst
@@ -78,6 +166,10 @@ def get_backend(*args):
return TorchBackend()
elif isinstance(args[0], jax_type):
return JaxBackend()
+ elif isinstance(args[0], cp_type): # pragma: no cover
+ return CupyBackend()
+ elif isinstance(args[0], tf_type):
+ return TensorflowBackend()
else:
raise ValueError("Unknown type of non implemented backend.")
@@ -94,7 +186,8 @@ def to_numpy(*args):
class Backend():
"""
Backend abstract class.
- Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`
+ Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`,
+ :py:class:`CupyBackend`, :py:class:`TensorflowBackend`
- The `__name__` class attribute refers to the name of the backend.
- The `__type__` class attribute refers to the data structure used by the backend.
@@ -665,6 +758,34 @@ class Backend():
"""
raise NotImplementedError()
+ def squeeze(self, a, axis=None):
+ r"""
+ Remove axes of length one from a.
+
+ This function follows the api from :any:`numpy.squeeze`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.squeeze.html
+ """
+ raise NotImplementedError()
+
+ def bitsize(self, type_as):
+ r"""
+ Gives the number of bits used by the data type of the given tensor.
+ """
+ raise NotImplementedError()
+
+ def device_type(self, type_as):
+ r"""
+ Returns CPU or GPU depending on the device where the given tensor is located.
+ """
+ raise NotImplementedError()
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ r"""
+ Executes a benchmark of the given callable with the given arguments.
+ """
+ raise NotImplementedError()
+
class NumpyBackend(Backend):
"""
@@ -902,6 +1023,29 @@ class NumpyBackend(Backend):
# numpy has implicit type conversion so we automatically validate the test
pass
+ def squeeze(self, a, axis=None):
+ return np.squeeze(a, axis=axis)
+
+ def bitsize(self, type_as):
+ return type_as.itemsize * 8
+
+ def device_type(self, type_as):
+ return "CPU"
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ results = dict()
+ for type_as in self.__type_list__:
+ inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
+ for _ in range(warmup_runs):
+ callable(*inputs)
+ t0 = time.perf_counter()
+ for _ in range(n_runs):
+ callable(*inputs)
+ t1 = time.perf_counter()
+ key = ("Numpy", self.device_type(type_as), self.bitsize(type_as))
+ results[key] = (t1 - t0) / n_runs
+ return results
+
class JaxBackend(Backend):
"""
@@ -920,9 +1064,16 @@ class JaxBackend(Backend):
def __init__(self):
self.rng_ = jax.random.PRNGKey(42)
- for d in jax.devices():
- self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d),
- jax.device_put(jnp.array(1, dtype=jnp.float64), d)]
+ self.__type_list__ = []
+ # available_devices = jax.devices("cpu")
+ available_devices = []
+ if xla_bridge.get_backend().platform == "gpu":
+ available_devices += jax.devices("gpu")
+ for d in available_devices:
+ self.__type_list__ += [
+ jax.device_put(jnp.array(1, dtype=jnp.float32), d),
+ jax.device_put(jnp.array(1, dtype=jnp.float64), d)
+ ]
def to_numpy(self, a):
return np.array(a)
@@ -1162,6 +1313,32 @@ class JaxBackend(Backend):
assert a_dtype == b_dtype, "Dtype discrepancy"
assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+ def squeeze(self, a, axis=None):
+ return jnp.squeeze(a, axis=axis)
+
+ def bitsize(self, type_as):
+ return type_as.dtype.itemsize * 8
+
+ def device_type(self, type_as):
+ return self.dtype_device(type_as)[1].platform.upper()
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ results = dict()
+
+ for type_as in self.__type_list__:
+ inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
+ for _ in range(warmup_runs):
+ a = callable(*inputs)
+ a.block_until_ready()
+ t0 = time.perf_counter()
+ for _ in range(n_runs):
+ a = callable(*inputs)
+ a.block_until_ready()
+ t1 = time.perf_counter()
+ key = ("Jax", self.device_type(type_as), self.bitsize(type_as))
+ results[key] = (t1 - t0) / n_runs
+ return results
+
class TorchBackend(Backend):
"""
@@ -1203,7 +1380,7 @@ class TorchBackend(Backend):
@staticmethod
def backward(ctx, grad_output):
# the gradients are grad
- return (None, None) + ctx.grads
+ return (None, None) + tuple(g * grad_output for g in ctx.grads)
self.ValFunction = ValFunction
@@ -1500,3 +1677,690 @@ class TorchBackend(Backend):
assert a_dtype == b_dtype, "Dtype discrepancy"
assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
+ def squeeze(self, a, axis=None):
+ if axis is None:
+ return torch.squeeze(a)
+ else:
+ return torch.squeeze(a, dim=axis)
+
+ def bitsize(self, type_as):
+ return torch.finfo(type_as.dtype).bits
+
+ def device_type(self, type_as):
+ return type_as.device.type.replace("cuda", "gpu").upper()
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ results = dict()
+ for type_as in self.__type_list__:
+ inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
+ for _ in range(warmup_runs):
+ callable(*inputs)
+ if self.device_type(type_as) == "GPU": # pragma: no cover
+ torch.cuda.synchronize()
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ else:
+ start = time.perf_counter()
+ for _ in range(n_runs):
+ callable(*inputs)
+ if self.device_type(type_as) == "GPU": # pragma: no cover
+ end.record()
+ torch.cuda.synchronize()
+ duration = start.elapsed_time(end) / 1000.
+ else:
+ end = time.perf_counter()
+ duration = end - start
+ key = ("Pytorch", self.device_type(type_as), self.bitsize(type_as))
+ results[key] = duration / n_runs
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ return results
+
+
+class CupyBackend(Backend): # pragma: no cover
+ """
+ CuPy implementation of the backend
+
+ - `__name__` is "cupy"
+ - `__type__` is cp.ndarray
+ """
+
+ __name__ = 'cupy'
+ __type__ = cp_type
+ __type_list__ = None
+
+ rng_ = None
+
+ def __init__(self):
+ self.rng_ = cp.random.RandomState()
+
+ self.__type_list__ = [
+ cp.array(1, dtype=cp.float32),
+ cp.array(1, dtype=cp.float64)
+ ]
+
+ def to_numpy(self, a):
+ return cp.asnumpy(a)
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return cp.asarray(a)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cp.asarray(a, dtype=type_as.dtype)
+
+ def set_gradients(self, val, inputs, grads):
+ # No gradients for cupy
+ return val
+
+ def zeros(self, shape, type_as=None):
+ if isinstance(shape, (list, tuple)):
+ shape = tuple(int(i) for i in shape)
+ if type_as is None:
+ return cp.zeros(shape)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cp.zeros(shape, dtype=type_as.dtype)
+
+ def ones(self, shape, type_as=None):
+ if isinstance(shape, (list, tuple)):
+ shape = tuple(int(i) for i in shape)
+ if type_as is None:
+ return cp.ones(shape)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cp.ones(shape, dtype=type_as.dtype)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return cp.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if isinstance(shape, (list, tuple)):
+ shape = tuple(int(i) for i in shape)
+ if type_as is None:
+ return cp.full(shape, fill_value)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cp.full(shape, fill_value, dtype=type_as.dtype)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return cp.eye(N, M)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cp.eye(N, M, dtype=type_as.dtype)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return cp.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return cp.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return cp.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return cp.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return cp.maximum(a, b)
+
+ def minimum(self, a, b):
+ return cp.minimum(a, b)
+
+ def abs(self, a):
+ return cp.abs(a)
+
+ def exp(self, a):
+ return cp.exp(a)
+
+ def log(self, a):
+ return cp.log(a)
+
+ def sqrt(self, a):
+ return cp.sqrt(a)
+
+ def power(self, a, exponents):
+ return cp.power(a, exponents)
+
+ def dot(self, a, b):
+ return cp.dot(a, b)
+
+ def norm(self, a):
+ return cp.sqrt(cp.sum(cp.square(a)))
+
+ def any(self, a):
+ return cp.any(a)
+
+ def isnan(self, a):
+ return cp.isnan(a)
+
+ def isinf(self, a):
+ return cp.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return cp.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return cp.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return cp.argsort(a, axis)
+
+ def searchsorted(self, a, v, side='left'):
+ if a.ndim == 1:
+ return cp.searchsorted(a, v, side)
+ else:
+ # this is not a very efficient way to make cupy
+ # searchsorted work on 2d arrays (one call per row)
+ ret = cp.empty(v.shape, dtype=int)
+ for i in range(a.shape[0]):
+ ret[i, :] = cp.searchsorted(a[i, :], v[i, :], side)
+ return ret
+
+ def flip(self, a, axis=None):
+ return cp.flip(a, axis)
+
+ def outer(self, a, b):
+ return cp.outer(a, b)
+
+ def clip(self, a, a_min, a_max):
+ return cp.clip(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return cp.repeat(a, repeats, axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return cp.take_along_axis(arr, indices, axis)
+
+ def concatenate(self, arrays, axis=0):
+ return cp.concatenate(arrays, axis)
+
+ def zero_pad(self, a, pad_width):
+ return cp.pad(a, pad_width)
+
+ def argmax(self, a, axis=None):
+ return cp.argmax(a, axis=axis)
+
+ def mean(self, a, axis=None):
+ return cp.mean(a, axis=axis)
+
+ def std(self, a, axis=None):
+ return cp.std(a, axis=axis)
+
+ def linspace(self, start, stop, num):
+ return cp.linspace(start, stop, num)
+
+ def meshgrid(self, a, b):
+ return cp.meshgrid(a, b)
+
+ def diag(self, a, k=0):
+ return cp.diag(a, k)
+
+ def unique(self, a):
+ return cp.unique(a)
+
+ def logsumexp(self, a, axis=None):
+ # Taken from
+ # https://github.com/scipy/scipy/blob/v1.7.1/scipy/special/_logsumexp.py#L7-L127
+ a_max = cp.amax(a, axis=axis, keepdims=True)
+
+ if a_max.ndim > 0:
+ a_max[~cp.isfinite(a_max)] = 0
+ elif not cp.isfinite(a_max):
+ a_max = 0
+
+ tmp = cp.exp(a - a_max)
+ s = cp.sum(tmp, axis=axis)
+ out = cp.log(s)
+ a_max = cp.squeeze(a_max, axis=axis)
+ out += a_max
+ return out
+
+ def stack(self, arrays, axis=0):
+ return cp.stack(arrays, axis)
+
+ def reshape(self, a, shape):
+ return cp.reshape(a, shape)
+
+ def seed(self, seed=None):
+ if seed is not None:
+ self.rng_.seed(seed)
+
+ def rand(self, *size, type_as=None):
+ if type_as is None:
+ return self.rng_.rand(*size)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return self.rng_.rand(*size, dtype=type_as.dtype)
+
+ def randn(self, *size, type_as=None):
+ if type_as is None:
+ return self.rng_.randn(*size)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return self.rng_.randn(*size, dtype=type_as.dtype)
+
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ data = self.from_numpy(data)
+ rows = self.from_numpy(rows)
+ cols = self.from_numpy(cols)
+ if type_as is None:
+ return cupyx.scipy.sparse.coo_matrix(
+ (data, (rows, cols)), shape=shape
+ )
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cupyx.scipy.sparse.coo_matrix(
+ (data, (rows, cols)), shape=shape, dtype=type_as.dtype
+ )
+
+ def issparse(self, a):
+ return cupyx.scipy.sparse.issparse(a)
+
+ def tocsr(self, a):
+ if self.issparse(a):
+ return a.tocsr()
+ else:
+ return cupyx.scipy.sparse.csr_matrix(a)
+
+ def eliminate_zeros(self, a, threshold=0.):
+ if threshold > 0:
+ if self.issparse(a):
+ a.data[self.abs(a.data) <= threshold] = 0
+ else:
+ a[self.abs(a) <= threshold] = 0
+ if self.issparse(a):
+ a.eliminate_zeros()
+ return a
+
+ def todense(self, a):
+ if self.issparse(a):
+ return a.toarray()
+ else:
+ return a
+
+ def where(self, condition, x, y):
+ return cp.where(condition, x, y)
+
+ def copy(self, a):
+ return a.copy()
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ return cp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+ def dtype_device(self, a):
+ return a.dtype, a.device
+
+ def assert_same_dtype_device(self, a, b):
+ a_dtype, a_device = self.dtype_device(a)
+ b_dtype, b_device = self.dtype_device(b)
+
+ # cupy has implicit type conversion so
+ # we automatically validate the test for type
+ assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
+ def squeeze(self, a, axis=None):
+ return cp.squeeze(a, axis=axis)
+
+ def bitsize(self, type_as):
+ return type_as.itemsize * 8
+
+ def device_type(self, type_as):
+ return "GPU"
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ mempool = cp.get_default_memory_pool()
+ pinned_mempool = cp.get_default_pinned_memory_pool()
+
+ results = dict()
+ for type_as in self.__type_list__:
+ inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
+ start_gpu = cp.cuda.Event()
+ end_gpu = cp.cuda.Event()
+ for _ in range(warmup_runs):
+ callable(*inputs)
+ start_gpu.synchronize()
+ start_gpu.record()
+ for _ in range(n_runs):
+ callable(*inputs)
+ end_gpu.record()
+ end_gpu.synchronize()
+ key = ("Cupy", self.device_type(type_as), self.bitsize(type_as))
+ t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu) / 1000.
+ results[key] = t_gpu / n_runs
+ mempool.free_all_blocks()
+ pinned_mempool.free_all_blocks()
+ return results
+
+
+class TensorflowBackend(Backend):
+
+ __name__ = "tf"
+ __type__ = tf_type
+ __type_list__ = None
+
+ rng_ = None
+
+ def __init__(self):
+ self.seed(None)
+
+ self.__type_list__ = [
+ tf.convert_to_tensor([1], dtype=tf.float32),
+ tf.convert_to_tensor([1], dtype=tf.float64)
+ ]
+
+ tmp = self.randn(15, 10)
+ try:
+ tmp.reshape((150, 1))
+ except AttributeError:
+ warnings.warn(
+ "To use TensorflowBackend, you need to activate the tensorflow "
+ "numpy API. You can activate it by running: \n"
+ "from tensorflow.python.ops.numpy_ops import np_config\n"
+ "np_config.enable_numpy_behavior()"
+ )
+
+ def to_numpy(self, a):
+ return a.numpy()
+
+ def from_numpy(self, a, type_as=None):
+ if not isinstance(a, self.__type__):
+ if type_as is None:
+ return tf.convert_to_tensor(a)
+ else:
+ return tf.convert_to_tensor(a, dtype=type_as.dtype)
+ else:
+ if type_as is None:
+ return a
+ else:
+ return tf.cast(a, dtype=type_as.dtype)
+
+ def set_gradients(self, val, inputs, grads):
+ @tf.custom_gradient
+ def tmp(input):
+ def grad(upstream):
+ return grads
+ return val, grad
+ return tmp(inputs)
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return tnp.zeros(shape)
+ else:
+ return tnp.zeros(shape, dtype=type_as.dtype)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return tnp.ones(shape)
+ else:
+ return tnp.ones(shape, dtype=type_as.dtype)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return tnp.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return tnp.full(shape, fill_value)
+ else:
+ return tnp.full(shape, fill_value, dtype=type_as.dtype)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return tnp.eye(N, M)
+ else:
+ return tnp.eye(N, M, dtype=type_as.dtype)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return tnp.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return tnp.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return tnp.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return tnp.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return tnp.maximum(a, b)
+
+ def minimum(self, a, b):
+ return tnp.minimum(a, b)
+
+ def dot(self, a, b):
+ if len(b.shape) == 1:
+ if len(a.shape) == 1:
+ # inner product
+ return tf.reduce_sum(tf.multiply(a, b))
+ else:
+ # matrix vector
+ return tf.linalg.matvec(a, b)
+ else:
+ if len(a.shape) == 1:
+ return tf.linalg.matvec(b.T, a.T).T
+ else:
+ return tf.matmul(a, b)
+
+ def abs(self, a):
+ return tnp.abs(a)
+
+ def exp(self, a):
+ return tnp.exp(a)
+
+ def log(self, a):
+ return tnp.log(a)
+
+ def sqrt(self, a):
+ return tnp.sqrt(a)
+
+ def power(self, a, exponents):
+ return tnp.power(a, exponents)
+
+ def norm(self, a):
+ return tf.math.reduce_euclidean_norm(a)
+
+ def any(self, a):
+ return tnp.any(a)
+
+ def isnan(self, a):
+ return tnp.isnan(a)
+
+ def isinf(self, a):
+ return tnp.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return tnp.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return tnp.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return tnp.argsort(a, axis)
+
+ def searchsorted(self, a, v, side='left'):
+ return tf.searchsorted(a, v, side=side)
+
+ def flip(self, a, axis=None):
+ return tnp.flip(a, axis)
+
+ def outer(self, a, b):
+ return tnp.outer(a, b)
+
+ def clip(self, a, a_min, a_max):
+ return tnp.clip(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return tnp.repeat(a, repeats, axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return tnp.take_along_axis(arr, indices, axis)
+
+ def concatenate(self, arrays, axis=0):
+ return tnp.concatenate(arrays, axis)
+
+ def zero_pad(self, a, pad_width):
+ return tnp.pad(a, pad_width, mode="constant")
+
+ def argmax(self, a, axis=None):
+ return tnp.argmax(a, axis=axis)
+
+ def mean(self, a, axis=None):
+ return tnp.mean(a, axis=axis)
+
+ def std(self, a, axis=None):
+ return tnp.std(a, axis=axis)
+
+ def linspace(self, start, stop, num):
+ return tnp.linspace(start, stop, num)
+
+ def meshgrid(self, a, b):
+ return tnp.meshgrid(a, b)
+
+ def diag(self, a, k=0):
+ return tnp.diag(a, k)
+
+ def unique(self, a):
+ return tf.sort(tf.unique(tf.reshape(a, [-1]))[0])
+
+ def logsumexp(self, a, axis=None):
+ return tf.math.reduce_logsumexp(a, axis=axis)
+
+ def stack(self, arrays, axis=0):
+ return tnp.stack(arrays, axis)
+
+ def reshape(self, a, shape):
+ return tnp.reshape(a, shape)
+
+ def seed(self, seed=None):
+ if isinstance(seed, int):
+ self.rng_ = tf.random.Generator.from_seed(seed)
+ elif isinstance(seed, tf.random.Generator):
+ self.rng_ = seed
+ elif seed is None:
+ self.rng_ = tf.random.Generator.from_non_deterministic_state()
+ else:
+ raise ValueError("Non compatible seed : {}".format(seed))
+
+ def rand(self, *size, type_as=None):
+ if type_as is None:
+ return self.rng_.uniform(size, minval=0., maxval=1.)
+ else:
+ return self.rng_.uniform(
+ size, minval=0., maxval=1., dtype=type_as.dtype
+ )
+
+ def randn(self, *size, type_as=None):
+ if type_as is None:
+ return self.rng_.normal(size)
+ else:
+ return self.rng_.normal(size, dtype=type_as.dtype)
+
+ def _convert_to_index_for_coo(self, tensor):
+ if isinstance(tensor, self.__type__):
+ return int(self.max(tensor)) + 1
+ else:
+ return int(np.max(tensor)) + 1
+
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ if shape is None:
+ shape = (
+ self._convert_to_index_for_coo(rows),
+ self._convert_to_index_for_coo(cols)
+ )
+ if type_as is not None:
+ data = self.from_numpy(data, type_as=type_as)
+
+ sparse_tensor = tf.sparse.SparseTensor(
+ indices=tnp.stack([rows, cols]).T,
+ values=data,
+ dense_shape=shape
+ )
+ # if type_as is not None:
+ # sparse_tensor = self.from_numpy(sparse_tensor, type_as=type_as)
+ # SparseTensor objects are not subscriptable, so we return a dense tensor instead
+ return self.todense(sparse_tensor)
+
+ def issparse(self, a):
+ return isinstance(a, tf.sparse.SparseTensor)
+
+ def tocsr(self, a):
+ return a
+
+ def eliminate_zeros(self, a, threshold=0.):
+ if self.issparse(a):
+ values = a.values
+ if threshold > 0:
+ mask = self.abs(values) <= threshold
+ else:
+ mask = values == 0
+ return tf.sparse.retain(a, ~mask)
+ else:
+ if threshold > 0:
+ a = tnp.where(self.abs(a) > threshold, a, 0.)
+ return a
+
+ def todense(self, a):
+ if self.issparse(a):
+ return tf.sparse.to_dense(tf.sparse.reorder(a))
+ else:
+ return a
+
+ def where(self, condition, x, y):
+ return tnp.where(condition, x, y)
+
+ def copy(self, a):
+ return tf.identity(a)
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ return tnp.allclose(
+ a, b, rtol=rtol, atol=atol, equal_nan=equal_nan
+ )
+
+ def dtype_device(self, a):
+ return a.dtype, a.device.split("device:")[1]
+
+ def assert_same_dtype_device(self, a, b):
+ a_dtype, a_device = self.dtype_device(a)
+ b_dtype, b_device = self.dtype_device(b)
+
+ assert a_dtype == b_dtype, "Dtype discrepancy"
+ assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
+ def squeeze(self, a, axis=None):
+ return tnp.squeeze(a, axis=axis)
+
+ def bitsize(self, type_as):
+ return type_as.dtype.size * 8
+
+ def device_type(self, type_as):
+ return self.dtype_device(type_as)[1].split(":")[0]
+
+ def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
+ results = dict()
+ device_contexts = [tf.device("/CPU:0")]
+ if len(tf.config.list_physical_devices('GPU')) > 0: # pragma: no cover
+ device_contexts.append(tf.device("/GPU:0"))
+
+ for device_context in device_contexts:
+ with device_context:
+ for type_as in self.__type_list__:
+ inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
+ for _ in range(warmup_runs):
+ callable(*inputs)
+ t0 = time.perf_counter()
+ for _ in range(n_runs):
+ res = callable(*inputs)
+ _ = res.numpy()
+ t1 = time.perf_counter()
+ key = (
+ "Tensorflow",
+ self.device_type(inputs[0]),
+ self.bitsize(type_as)
+ )
+ results[key] = (t1 - t0) / n_runs
+
+ return results