summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>2022-05-06 13:34:18 +0200
committerGitHub <noreply@github.com>2022-05-06 13:34:18 +0200
commit726e84e1e9f2832ea5ad156f62a5e3636c1fd3d3 (patch)
treec98540ab45c22d139912d95a19bbada1ca4b286a
parentccc076e0fc535b2c734214c0ac1936e9e2cbeb62 (diff)
[MRG] Torch random generator not working for Cuda tensor (#373)
* Solve bug * Update release file
-rw-r--r--RELEASES.md5
-rw-r--r--ot/backend.py18
2 files changed, 19 insertions, 4 deletions
diff --git a/RELEASES.md b/RELEASES.md
index 3461832..c06721f 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -6,6 +6,11 @@
- Added Generalized Wasserstein Barycenter solver + example (PR #372)
+#### Closed issues
+
+- Fixed an issue where we could not ask TorchBackend to place a random tensor on GPU
+ (Issue #371, PR #373)
+
## 0.8.2
diff --git a/ot/backend.py b/ot/backend.py
index 361ffba..e4b48e1 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1507,15 +1507,19 @@ class TorchBackend(Backend):
def __init__(self):
- self.rng_ = torch.Generator()
+ self.rng_ = torch.Generator("cpu")
self.rng_.seed()
self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
torch.tensor(1, dtype=torch.float64)]
if torch.cuda.is_available():
+ self.rng_cuda_ = torch.Generator("cuda")
+ self.rng_cuda_.seed()
self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))
+ else:
+ self.rng_cuda_ = torch.Generator("cpu")
from torch.autograd import Function
@@ -1761,20 +1765,26 @@ class TorchBackend(Backend):
def seed(self, seed=None):
if isinstance(seed, int):
self.rng_.manual_seed(seed)
+ self.rng_cuda_.manual_seed(seed)
elif isinstance(seed, torch.Generator):
- self.rng_ = seed
+ if self.device_type(seed) == "GPU":
+ self.rng_cuda_ = seed
+ else:
+ self.rng_ = seed
else:
raise ValueError("Non compatible seed : {}".format(seed))
def rand(self, *size, type_as=None):
if type_as is not None:
- return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device)
+ generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_
+ return torch.rand(size=size, generator=generator, dtype=type_as.dtype, device=type_as.device)
else:
return torch.rand(size=size, generator=self.rng_)
def randn(self, *size, type_as=None):
if type_as is not None:
- return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device)
+ generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_
+ return torch.randn(size=size, dtype=type_as.dtype, generator=generator, device=type_as.device)
else:
return torch.randn(size=size, generator=self.rng_)