path: root/ot/backend.py
author    Rémi Flamary <remi.flamary@gmail.com>  2021-11-04 15:19:57 +0100
committer GitHub <noreply@github.com>            2021-11-04 15:19:57 +0100
commit    0e431c203a66c6d48e6bb1efeda149460472a0f0 (patch)
tree      22a447a1dbb1505b18f9e426e1761cf6b328b6eb /ot/backend.py
parent    2fe69eb130827560ada704bc25998397c4357821 (diff)
[MRG] Add tests about type and GPU for emd/emd2 + 1d variants + wasserstein1d (#304)
* new test gpu
* pep 8 of course
* debug torch
* jax with gpu
* device put
* device put
* it works
* emd1d and emd2_1d working
* emd_1d and emd2_1d done
* cleanup
* of course
* should work on gpu now
* tests done + pep8
Diffstat (limited to 'ot/backend.py')
-rw-r--r--  ot/backend.py  20
1 file changed, 19 insertions(+), 1 deletion(-)
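What the new tests exercise, in outline: every backend now advertises a list of sample scalars covering each supported dtype and device, and emd/emd2 (plus the 1d variants) are checked to return results matching their inputs. A minimal sketch of that loop using only the public POT API (the actual tests added by this PR live in the test suite and differ in detail):

    import numpy as np
    import ot
    from ot.backend import get_backend_list

    for nx in get_backend_list():
        for tp in nx.__type_list__:
            # cast inputs to this dtype/device combination
            a = nx.from_numpy(np.array([0.5, 0.5]), type_as=tp)
            M = nx.from_numpy(np.array([[0.0, 1.0], [1.0, 0.0]]), type_as=tp)
            G = ot.emd(a, a, M)
            assert G.dtype == tp.dtype  # same dtype in, same dtype out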
diff --git a/ot/backend.py b/ot/backend.py
index d3df44c..55e10d3 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -102,6 +102,7 @@ class Backend():
 
     __name__ = None
     __type__ = None
+    __type_list__ = None
 
     rng_ = None
@@ -663,6 +664,8 @@ class NumpyBackend(Backend):
 
     __name__ = 'numpy'
     __type__ = np.ndarray
+    __type_list__ = [np.array(1, dtype=np.float32),
+                     np.array(1, dtype=np.float64)]
 
     rng_ = np.random.RandomState()
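Each entry is a zero-dimensional array used purely as a dtype template: casting against `tp.dtype` rather than passing a dtype object directly keeps the same test loop working across backends. The pattern in isolation:

    import numpy as np

    tp = np.array(1, dtype=np.float32)           # 0-d array carrying only a dtype
    x = np.asarray([1.0, 2.0]).astype(tp.dtype)  # cast against the template
    assert x.dtype == np.float32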
@@ -888,12 +891,17 @@ class JaxBackend(Backend):
 
     __name__ = 'jax'
     __type__ = jax_type
+    __type_list__ = None
 
     rng_ = None
 
     def __init__(self):
         self.rng_ = jax.random.PRNGKey(42)
 
+        for d in jax.devices():
+            self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d),
+                                  jax.device_put(jnp.array(1, dtype=np.float64), d)]
+
     def to_numpy(self, a):
         return np.array(a)
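`jax.device_put` is what pins an array to a particular device, and `jax.devices()` lists the available ones (just the CPU on a CPU-only install, GPUs/TPUs when present). A standalone illustration, assuming only that JAX is installed:

    import jax
    import jax.numpy as jnp

    x = jnp.array(1, dtype=jnp.float32)
    for d in jax.devices():        # e.g. [CpuDevice(id=0)] without a GPU
        y = jax.device_put(x, d)   # copy x onto device d
        print(d, y.dtype)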
@@ -901,7 +909,7 @@ class JaxBackend(Backend):
         if type_as is None:
             return jnp.array(a)
         else:
-            return jnp.array(a).astype(type_as.dtype)
+            return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device())
 
     def set_gradients(self, val, inputs, grads):
         from jax.flatten_util import ravel_pytree
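With this change, `from_numpy` no longer just casts: it also places the result on the same device as `type_as`, so arrays created inside a solver stay co-located with the caller's inputs. A sketch of the same cast-and-place sequence (note that `device_buffer.device()` is the accessor available in JAX versions of this era; recent JAX exposes the device through `.devices()` instead):

    import numpy as np
    import jax
    import jax.numpy as jnp

    type_as = jax.device_put(jnp.array(1, dtype=jnp.float32), jax.devices()[0])
    out = jax.device_put(jnp.array(np.array([1.0, 2.0])).astype(type_as.dtype),
                         type_as.device_buffer.device())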
@@ -1130,6 +1138,7 @@ class TorchBackend(Backend):
 
     __name__ = 'torch'
     __type__ = torch_type
+    __type_list__ = None
 
     rng_ = None
@@ -1138,6 +1147,13 @@ class TorchBackend(Backend):
         self.rng_ = torch.Generator()
         self.rng_.seed()
 
+        self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
+                              torch.tensor(1, dtype=torch.float64)]
+
+        if torch.cuda.is_available():
+            self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
+            self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))
+
         from torch.autograd import Function
 
         # define a function that takes inputs val and grads
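The CUDA entries are appended only when a GPU is actually visible, so the same dtype/device loop runs unchanged on CPU-only and GPU machines. The guard pattern on its own, assuming only PyTorch:

    import torch

    type_list = [torch.tensor(1, dtype=torch.float32),
                 torch.tensor(1, dtype=torch.float64)]
    if torch.cuda.is_available():                  # skipped on CPU-only machines
        type_list += [t.cuda() for t in type_list]
    for t in type_list:
        print(t.dtype, t.device)                   # e.g. torch.float32 cpu / cuda:0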
@@ -1160,6 +1176,8 @@ class TorchBackend(Backend):
         return a.cpu().detach().numpy()
 
     def from_numpy(self, a, type_as=None):
+        if isinstance(a, float):
+            a = np.array(a)
         if type_as is None:
             return torch.from_numpy(a)
         else:
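The isinstance guard is needed because `torch.from_numpy` accepts only NumPy arrays: a bare Python float raises a TypeError, hence the `np.array(a)` wrapper, which turns the scalar into a 0-d array first. A quick demonstration:

    import numpy as np
    import torch

    try:
        torch.from_numpy(0.5)                # TypeError: expected np.ndarray
    except TypeError as err:
        print(err)

    t = torch.from_numpy(np.array(0.5))      # works: 0-d tensor
    print(t.dtype)                           # torch.float64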