diff options
author | Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> | 2022-05-06 13:34:18 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-06 13:34:18 +0200 |
commit | 726e84e1e9f2832ea5ad156f62a5e3636c1fd3d3 (patch) | |
tree | c98540ab45c22d139912d95a19bbada1ca4b286a /ot/backend.py | |
parent | ccc076e0fc535b2c734214c0ac1936e9e2cbeb62 (diff) |
[MRG] Torch random generator not working for Cuda tensor (#373)
* Solve bug
* Update release file
Diffstat (limited to 'ot/backend.py')
-rw-r--r-- | ot/backend.py | 18 |
1 file changed, 14 insertions(+), 4 deletions(-)
diff --git a/ot/backend.py b/ot/backend.py index 361ffba..e4b48e1 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1507,15 +1507,19 @@ class TorchBackend(Backend): def __init__(self): - self.rng_ = torch.Generator() + self.rng_ = torch.Generator("cpu") self.rng_.seed() self.__type_list__ = [torch.tensor(1, dtype=torch.float32), torch.tensor(1, dtype=torch.float64)] if torch.cuda.is_available(): + self.rng_cuda_ = torch.Generator("cuda") + self.rng_cuda_.seed() self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda')) self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda')) + else: + self.rng_cuda_ = torch.Generator("cpu") from torch.autograd import Function @@ -1761,20 +1765,26 @@ class TorchBackend(Backend): def seed(self, seed=None): if isinstance(seed, int): self.rng_.manual_seed(seed) + self.rng_cuda_.manual_seed(seed) elif isinstance(seed, torch.Generator): - self.rng_ = seed + if self.device_type(seed) == "GPU": + self.rng_cuda_ = seed + else: + self.rng_ = seed else: raise ValueError("Non compatible seed : {}".format(seed)) def rand(self, *size, type_as=None): if type_as is not None: - return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device) + generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_ + return torch.rand(size=size, generator=generator, dtype=type_as.dtype, device=type_as.device) else: return torch.rand(size=size, generator=self.rng_) def randn(self, *size, type_as=None): if type_as is not None: - return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device) + generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_ + return torch.randn(size=size, dtype=type_as.dtype, generator=generator, device=type_as.device) else: return torch.randn(size=size, generator=self.rng_) |