diff options
author | ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> | 2021-11-05 15:57:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-05 15:57:08 +0100 |
commit | 0eac835c70cc1a13bb998f3b6cdb0515fafc05e1 (patch) | |
tree | b0c0fbce0109ba460a67a6356dc0ff03e2b3c1d5 /ot | |
parent | 0e431c203a66c6d48e6bb1efeda149460472a0f0 (diff) |
[MRG] Tests with types/device on sliced/bregman/gromov functions (#303)
* First draft : making pytest use gpu for torch testing
* bug solve
* Revert "bug solve"
This reverts commit 29b013abd162f8693128f26d8129186b79923609.
* Revert "First draft : making pytest use gpu for torch testing"
This reverts commit 2778175bcc338016c704efa4187d132fe5162e3a.
* sliced
* sliced
* ot 1dsolver
* bregman
* better print
* jax works with sinkhorn, sinkhorn_log and sinkhornn_stabilized, no need to skip them
* gromov & entropic gromov
Diffstat (limited to 'ot')
-rw-r--r-- | ot/backend.py | 59 | ||||
-rw-r--r-- | ot/sliced.py | 8 |
2 files changed, 56 insertions, 11 deletions
diff --git a/ot/backend.py b/ot/backend.py
index 55e10d3..a044f84 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -653,6 +653,18 @@ class Backend():
         """
         raise NotImplementedError()
 
+    def dtype_device(self, a):
+        r"""
+        Returns the dtype and the device of the given tensor.
+        """
+        raise NotImplementedError()
+
+    def assert_same_dtype_device(self, a, b):
+        r"""
+        Checks whether or not the two given inputs have the same dtype as well as the same device
+        """
+        raise NotImplementedError()
+
 
 class NumpyBackend(Backend):
     """
@@ -880,6 +892,16 @@ class NumpyBackend(Backend):
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
         return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
 
+    def dtype_device(self, a):
+        if hasattr(a, "dtype"):
+            return a.dtype, "cpu"
+        else:
+            return type(a), "cpu"
+
+    def assert_same_dtype_device(self, a, b):
+        # numpy has implicit type conversion so we automatically validate the test
+        pass
+
 
 class JaxBackend(Backend):
     """
@@ -899,17 +921,20 @@ class JaxBackend(Backend):
         self.rng_ = jax.random.PRNGKey(42)
 
         for d in jax.devices():
-            self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d),
-                                  jax.device_put(jnp.array(1, dtype=np.float64), d)]
+            self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d),
+                                  jax.device_put(jnp.array(1, dtype=jnp.float64), d)]
 
     def to_numpy(self, a):
         return np.array(a)
 
+    def _change_device(self, a, type_as):
+        return jax.device_put(a, type_as.device_buffer.device())
+
     def from_numpy(self, a, type_as=None):
         if type_as is None:
             return jnp.array(a)
         else:
-            return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device())
+            return self._change_device(jnp.array(a).astype(type_as.dtype), type_as)
 
     def set_gradients(self, val, inputs, grads):
         from jax.flatten_util import ravel_pytree
@@ -928,13 +953,13 @@ class JaxBackend(Backend):
         if type_as is None:
             return jnp.zeros(shape)
         else:
-            return jnp.zeros(shape, dtype=type_as.dtype)
+            return self._change_device(jnp.zeros(shape, dtype=type_as.dtype), type_as)
 
     def ones(self, shape, type_as=None):
         if type_as is None:
             return jnp.ones(shape)
         else:
-            return jnp.ones(shape, dtype=type_as.dtype)
+            return self._change_device(jnp.ones(shape, dtype=type_as.dtype), type_as)
 
     def arange(self, stop, start=0, step=1, type_as=None):
         return jnp.arange(start, stop, step)
@@ -943,13 +968,13 @@ class JaxBackend(Backend):
         if type_as is None:
             return jnp.full(shape, fill_value)
         else:
-            return jnp.full(shape, fill_value, dtype=type_as.dtype)
+            return self._change_device(jnp.full(shape, fill_value, dtype=type_as.dtype), type_as)
 
     def eye(self, N, M=None, type_as=None):
         if type_as is None:
             return jnp.eye(N, M)
         else:
-            return jnp.eye(N, M, dtype=type_as.dtype)
+            return self._change_device(jnp.eye(N, M, dtype=type_as.dtype), type_as)
 
     def sum(self, a, axis=None, keepdims=False):
         return jnp.sum(a, axis, keepdims=keepdims)
@@ -1127,6 +1152,16 @@ class JaxBackend(Backend):
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
         return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
 
+    def dtype_device(self, a):
+        return a.dtype, a.device_buffer.device()
+
+    def assert_same_dtype_device(self, a, b):
+        a_dtype, a_device = self.dtype_device(a)
+        b_dtype, b_device = self.dtype_device(b)
+
+        assert a_dtype == b_dtype, "Dtype discrepancy"
+        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
 
 class TorchBackend(Backend):
     """
@@ -1455,3 +1490,13 @@ class TorchBackend(Backend):
 
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
         return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+    def dtype_device(self, a):
+        return a.dtype, a.device
+
+    def assert_same_dtype_device(self, a, b):
+        a_dtype, a_device = self.dtype_device(a)
+        b_dtype, b_device = self.dtype_device(b)
+
+        assert a_dtype == b_dtype, "Dtype discrepancy"
+        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
diff --git a/ot/sliced.py b/ot/sliced.py
index 7c09111..cf2d3be 100644
--- a/ot/sliced.py
+++ b/ot/sliced.py
@@ -139,9 +139,9 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
                          X_t.shape[1]))
 
     if a is None:
-        a = nx.full(n, 1 / n)
+        a = nx.full(n, 1 / n, type_as=X_s)
     if b is None:
-        b = nx.full(m, 1 / m)
+        b = nx.full(m, 1 / m, type_as=X_s)
 
     d = X_s.shape[1]
 
@@ -238,9 +238,9 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50,
                          X_t.shape[1]))
 
     if a is None:
-        a = nx.full(n, 1 / n)
+        a = nx.full(n, 1 / n, type_as=X_s)
     if b is None:
-        b = nx.full(m, 1 / m)
+        b = nx.full(m, 1 / m, type_as=X_s)
 
     d = X_s.shape[1]