From 726e84e1e9f2832ea5ad156f62a5e3636c1fd3d3 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> Date: Fri, 6 May 2022 13:34:18 +0200 Subject: [MRG] Torch random generator not working for Cuda tensor (#373) * Solve bug * Update release file --- RELEASES.md | 5 +++++ ot/backend.py | 18 ++++++++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 3461832..c06721f 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,6 +6,11 @@ - Added Generalized Wasserstein Barycenter solver + example (PR #372) +#### Closed issues + +- Fixed an issue where we could not ask TorchBackend to place a random tensor on GPU + (Issue #371, PR #373) + ## 0.8.2 diff --git a/ot/backend.py b/ot/backend.py index 361ffba..e4b48e1 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1507,15 +1507,19 @@ class TorchBackend(Backend): def __init__(self): - self.rng_ = torch.Generator() + self.rng_ = torch.Generator("cpu") self.rng_.seed() self.__type_list__ = [torch.tensor(1, dtype=torch.float32), torch.tensor(1, dtype=torch.float64)] if torch.cuda.is_available(): + self.rng_cuda_ = torch.Generator("cuda") + self.rng_cuda_.seed() self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda')) self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda')) + else: + self.rng_cuda_ = torch.Generator("cpu") from torch.autograd import Function @@ -1761,20 +1765,26 @@ class TorchBackend(Backend): def seed(self, seed=None): if isinstance(seed, int): self.rng_.manual_seed(seed) + self.rng_cuda_.manual_seed(seed) elif isinstance(seed, torch.Generator): - self.rng_ = seed + if self.device_type(seed) == "GPU": + self.rng_cuda_ = seed + else: + self.rng_ = seed else: raise ValueError("Non compatible seed : {}".format(seed)) def rand(self, *size, type_as=None): if type_as is not None: - return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device) + generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_ + return torch.rand(size=size, generator=generator, dtype=type_as.dtype, device=type_as.device) else: return torch.rand(size=size, generator=self.rng_) def randn(self, *size, type_as=None): if type_as is not None: - return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device) + generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_ + return torch.randn(size=size, dtype=type_as.dtype, generator=generator, device=type_as.device) else: return torch.randn(size=size, generator=self.rng_) -- cgit v1.2.3