path: root/ot/backend.py
author    ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>    2021-11-05 15:57:08 +0100
committer GitHub <noreply@github.com>    2021-11-05 15:57:08 +0100
commit    0eac835c70cc1a13bb998f3b6cdb0515fafc05e1 (patch)
tree      b0c0fbce0109ba460a67a6356dc0ff03e2b3c1d5 /ot/backend.py
parent    0e431c203a66c6d48e6bb1efeda149460472a0f0 (diff)
[MRG] Tests with types/device on sliced/bregman/gromov functions (#303)
* First draft: making pytest use gpu for torch testing
* bug solve
* Revert "bug solve". This reverts commit 29b013abd162f8693128f26d8129186b79923609.
* Revert "First draft: making pytest use gpu for torch testing". This reverts commit 2778175bcc338016c704efa4187d132fe5162e3a.
* sliced
* sliced
* ot 1dsolver
* bregman
* better print
* jax works with sinkhorn, sinkhorn_log and sinkhorn_stabilized, no need to skip them
* gromov & entropic gromov
Diffstat (limited to 'ot/backend.py')
-rw-r--r--    ot/backend.py    59
1 file changed, 52 insertions, 7 deletions
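The diff below adds dtype/device introspection helpers to each backend so the new tests can check that outputs keep the dtype and device of their inputs. A rough usage sketch (illustration only, not part of this commit; it assumes POT's existing ot.unif, ot.sinkhorn and ot.backend.get_backend helpers):

    import numpy as np
    import ot
    from ot.backend import get_backend

    a, b = ot.unif(50), ot.unif(50)      # uniform histograms
    M = np.random.rand(50, 50)           # random cost matrix

    nx = get_backend(a, b, M)            # resolves to NumpyBackend here
    G = ot.sinkhorn(a, b, M, reg=1.0)    # entropic OT plan

    print(nx.dtype_device(G))            # (dtype('float64'), 'cpu') for NumPy arrays
    nx.assert_same_dtype_device(M, G)    # no-op for NumPy, real assert for Jax/Torch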
diff --git a/ot/backend.py b/ot/backend.py
index 55e10d3..a044f84 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -653,6 +653,18 @@ class Backend():
         """
         raise NotImplementedError()
 
+    def dtype_device(self, a):
+        r"""
+        Returns the dtype and the device of the given tensor.
+        """
+        raise NotImplementedError()
+
+    def assert_same_dtype_device(self, a, b):
+        r"""
+        Checks whether or not the two given inputs have the same dtype as well as the same device
+        """
+        raise NotImplementedError()
+
 
 class NumpyBackend(Backend):
     """
@@ -880,6 +892,16 @@ class NumpyBackend(Backend):
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
         return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
 
+    def dtype_device(self, a):
+        if hasattr(a, "dtype"):
+            return a.dtype, "cpu"
+        else:
+            return type(a), "cpu"
+
+    def assert_same_dtype_device(self, a, b):
+        # numpy has implicit type conversion so we automatically validate the test
+        pass
+
 
 class JaxBackend(Backend):
     """
@@ -899,17 +921,20 @@ class JaxBackend(Backend):
         self.rng_ = jax.random.PRNGKey(42)
 
         for d in jax.devices():
-            self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d),
-                                  jax.device_put(jnp.array(1, dtype=np.float64), d)]
+            self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d),
+                                  jax.device_put(jnp.array(1, dtype=jnp.float64), d)]
 
     def to_numpy(self, a):
         return np.array(a)
 
+    def _change_device(self, a, type_as):
+        return jax.device_put(a, type_as.device_buffer.device())
+
     def from_numpy(self, a, type_as=None):
         if type_as is None:
             return jnp.array(a)
         else:
-            return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device())
+            return self._change_device(jnp.array(a).astype(type_as.dtype), type_as)
 
     def set_gradients(self, val, inputs, grads):
         from jax.flatten_util import ravel_pytree
@@ -928,13 +953,13 @@ class JaxBackend(Backend):
         if type_as is None:
             return jnp.zeros(shape)
         else:
-            return jnp.zeros(shape, dtype=type_as.dtype)
+            return self._change_device(jnp.zeros(shape, dtype=type_as.dtype), type_as)
 
     def ones(self, shape, type_as=None):
         if type_as is None:
             return jnp.ones(shape)
         else:
-            return jnp.ones(shape, dtype=type_as.dtype)
+            return self._change_device(jnp.ones(shape, dtype=type_as.dtype), type_as)
 
     def arange(self, stop, start=0, step=1, type_as=None):
         return jnp.arange(start, stop, step)
@@ -943,13 +968,13 @@ class JaxBackend(Backend):
         if type_as is None:
             return jnp.full(shape, fill_value)
         else:
-            return jnp.full(shape, fill_value, dtype=type_as.dtype)
+            return self._change_device(jnp.full(shape, fill_value, dtype=type_as.dtype), type_as)
 
     def eye(self, N, M=None, type_as=None):
         if type_as is None:
             return jnp.eye(N, M)
         else:
-            return jnp.eye(N, M, dtype=type_as.dtype)
+            return self._change_device(jnp.eye(N, M, dtype=type_as.dtype), type_as)
 
     def sum(self, a, axis=None, keepdims=False):
         return jnp.sum(a, axis, keepdims=keepdims)
@@ -1127,6 +1152,16 @@ class JaxBackend(Backend):
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
         return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
 
+    def dtype_device(self, a):
+        return a.dtype, a.device_buffer.device()
+
+    def assert_same_dtype_device(self, a, b):
+        a_dtype, a_device = self.dtype_device(a)
+        b_dtype, b_device = self.dtype_device(b)
+
+        assert a_dtype == b_dtype, "Dtype discrepancy"
+        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
 
 class TorchBackend(Backend):
     """
@@ -1455,3 +1490,13 @@ class TorchBackend(Backend):
 
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
         return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+    def dtype_device(self, a):
+        return a.dtype, a.device
+
+    def assert_same_dtype_device(self, a, b):
+        a_dtype, a_device = self.dtype_device(a)
+        b_dtype, b_device = self.dtype_device(b)
+
+        assert a_dtype == b_dtype, "Dtype discrepancy"
+        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
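For the Torch backend the same helpers compare the tensor's torch dtype and device directly. A minimal sketch of the behaviour this enables, assuming PyTorch is installed (CUDA is used only if available) and that tensor creation with type_as propagates both dtype and device, as TorchBackend does:

    import torch
    from ot.backend import get_backend

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.rand(10, dtype=torch.float64, device=device)

    nx = get_backend(x)                  # resolves to TorchBackend
    z = nx.zeros((3, 3), type_as=x)      # inherits x's dtype and device

    print(nx.dtype_device(z))            # (torch.float64, x.device)
    nx.assert_same_dtype_device(x, z)    # passes: same dtype and device
    # Mixing dtypes or devices would instead fail with the
    # "Dtype discrepancy" / "Device discrepancy" assertion messages above.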