diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2021-11-04 15:19:57 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-04 15:19:57 +0100 |
commit | 0e431c203a66c6d48e6bb1efeda149460472a0f0 (patch) | |
tree | 22a447a1dbb1505b18f9e426e1761cf6b328b6eb /ot/backend.py | |
parent | 2fe69eb130827560ada704bc25998397c4357821 (diff) |
[MRG] Add tests about type and GPU for emd/emd2 + 1d variants + wasserstein1d (#304)
* new test gpu
* pep 8 of couse
* debug torch
* jax with gpu
* device put
* device put
* it works
* emd1d and emd2_1d working
* emd_1d and emd2_1d done
* cleanup
* of course
* should work on gpu now
* tests done+ pep8
Diffstat (limited to 'ot/backend.py')
-rw-r--r-- | ot/backend.py | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/ot/backend.py b/ot/backend.py index d3df44c..55e10d3 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -102,6 +102,7 @@ class Backend(): __name__ = None __type__ = None + __type_list__ = None rng_ = None @@ -663,6 +664,8 @@ class NumpyBackend(Backend): __name__ = 'numpy' __type__ = np.ndarray + __type_list__ = [np.array(1, dtype=np.float32), + np.array(1, dtype=np.float64)] rng_ = np.random.RandomState() @@ -888,12 +891,17 @@ class JaxBackend(Backend): __name__ = 'jax' __type__ = jax_type + __type_list__ = None rng_ = None def __init__(self): self.rng_ = jax.random.PRNGKey(42) + for d in jax.devices(): + self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d), + jax.device_put(jnp.array(1, dtype=np.float64), d)] + def to_numpy(self, a): return np.array(a) @@ -901,7 +909,7 @@ class JaxBackend(Backend): if type_as is None: return jnp.array(a) else: - return jnp.array(a).astype(type_as.dtype) + return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device()) def set_gradients(self, val, inputs, grads): from jax.flatten_util import ravel_pytree @@ -1130,6 +1138,7 @@ class TorchBackend(Backend): __name__ = 'torch' __type__ = torch_type + __type_list__ = None rng_ = None @@ -1138,6 +1147,13 @@ class TorchBackend(Backend): self.rng_ = torch.Generator() self.rng_.seed() + self.__type_list__ = [torch.tensor(1, dtype=torch.float32), + torch.tensor(1, dtype=torch.float64)] + + if torch.cuda.is_available(): + self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda')) + self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda')) + from torch.autograd import Function # define a function that takes inputs val and grads @@ -1160,6 +1176,8 @@ class TorchBackend(Backend): return a.cpu().detach().numpy() def from_numpy(self, a, type_as=None): + if isinstance(a, float): + a = np.array(a) if type_as is None: return torch.from_numpy(a) else: |