From 0eac835c70cc1a13bb998f3b6cdb0515fafc05e1 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Fri, 5 Nov 2021 15:57:08 +0100 Subject: [MRG] Tests with types/device on sliced/bregman/gromov functions (#303) * First draft : making pytest use gpu for torch testing * bug solve * Revert "bug solve" This reverts commit 29b013abd162f8693128f26d8129186b79923609. * Revert "First draft : making pytest use gpu for torch testing" This reverts commit 2778175bcc338016c704efa4187d132fe5162e3a. * sliced * sliced * ot 1dsolver * bregman * better print * jax works with sinkhorn, sinkhorn_log and sinkhornn_stabilized, no need to skip them * gromov & entropic gromov --- ot/backend.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 52 insertions(+), 7 deletions(-) (limited to 'ot/backend.py') diff --git a/ot/backend.py b/ot/backend.py index 55e10d3..a044f84 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -653,6 +653,18 @@ class Backend(): """ raise NotImplementedError() + def dtype_device(self, a): + r""" + Returns the dtype and the device of the given tensor. + """ + raise NotImplementedError() + + def assert_same_dtype_device(self, a, b): + r""" + Checks whether or not the two given inputs have the same dtype as well as the same device + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -880,6 +892,16 @@ class NumpyBackend(Backend): def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + def dtype_device(self, a): + if hasattr(a, "dtype"): + return a.dtype, "cpu" + else: + return type(a), "cpu" + + def assert_same_dtype_device(self, a, b): + # numpy has implicit type conversion so we automatically validate the test + pass + class JaxBackend(Backend): """ @@ -899,17 +921,20 @@ class JaxBackend(Backend): self.rng_ = jax.random.PRNGKey(42) for d in jax.devices(): - self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d), - jax.device_put(jnp.array(1, dtype=np.float64), d)] + self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d), + jax.device_put(jnp.array(1, dtype=jnp.float64), d)] def to_numpy(self, a): return np.array(a) + def _change_device(self, a, type_as): + return jax.device_put(a, type_as.device_buffer.device()) + def from_numpy(self, a, type_as=None): if type_as is None: return jnp.array(a) else: - return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device()) + return self._change_device(jnp.array(a).astype(type_as.dtype), type_as) def set_gradients(self, val, inputs, grads): from jax.flatten_util import ravel_pytree @@ -928,13 +953,13 @@ class JaxBackend(Backend): if type_as is None: return jnp.zeros(shape) else: - return jnp.zeros(shape, dtype=type_as.dtype) + return self._change_device(jnp.zeros(shape, dtype=type_as.dtype), type_as) def ones(self, shape, type_as=None): if type_as is None: return jnp.ones(shape) else: - return jnp.ones(shape, dtype=type_as.dtype) + return self._change_device(jnp.ones(shape, dtype=type_as.dtype), type_as) def arange(self, stop, start=0, step=1, type_as=None): return jnp.arange(start, stop, step) @@ -943,13 +968,13 @@ class JaxBackend(Backend): if type_as is None: return jnp.full(shape, fill_value) else: - return jnp.full(shape, fill_value, dtype=type_as.dtype) + return self._change_device(jnp.full(shape, fill_value, dtype=type_as.dtype), type_as) def eye(self, N, M=None, type_as=None): if type_as is None: return jnp.eye(N, M) else: - return jnp.eye(N, M, dtype=type_as.dtype) + return self._change_device(jnp.eye(N, M, dtype=type_as.dtype), type_as) def sum(self, a, axis=None, keepdims=False): return jnp.sum(a, axis, keepdims=keepdims) @@ -1127,6 +1152,16 @@ class JaxBackend(Backend): def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + def dtype_device(self, a): + return a.dtype, a.device_buffer.device() + + def assert_same_dtype_device(self, a, b): + a_dtype, a_device = self.dtype_device(a) + b_dtype, b_device = self.dtype_device(b) + + assert a_dtype == b_dtype, "Dtype discrepancy" + assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" + class TorchBackend(Backend): """ @@ -1455,3 +1490,13 @@ class TorchBackend(Backend): def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + + def dtype_device(self, a): + return a.dtype, a.device + + def assert_same_dtype_device(self, a, b): + a_dtype, a_device = self.dtype_device(a) + b_dtype, b_device = self.dtype_device(b) + + assert a_dtype == b_dtype, "Dtype discrepancy" + assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" -- cgit v1.2.3