summaryrefslogtreecommitdiff
path: root/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
diff options
context:
space:
mode:
authorNathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>2022-08-01 17:38:05 +0200
committerGitHub <noreply@github.com>2022-08-01 17:38:05 +0200
commit0138dcf636c3f3f0e63110b08a8249f065e1fa73 (patch)
tree704ed82dff4c5807ef983c64a60bad6a0e54ef3d /examples/backends/plot_sliced_wass_grad_flow_pytorch.py
parent818c7ace20da36d8042b0d7ad7a712b27f7afd59 (diff)
[MRG] Solve example throwing an error when executed on a GPU (#391)
* Solve example throwing an error when executed on a GPU * add PR to releases.md * update pep8 command * pep8
Diffstat (limited to 'examples/backends/plot_sliced_wass_grad_flow_pytorch.py')
-rw-r--r--examples/backends/plot_sliced_wass_grad_flow_pytorch.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
index 59e0042..f00de50 100644
--- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
+++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
@@ -74,7 +74,7 @@ x_all = np.zeros((nb_iter_max, x1.shape[0], 2))
loss_iter = []
# generator for random permutations
-gen = torch.Generator()
+gen = torch.Generator(device=device)
gen.manual_seed(42)
for i in range(nb_iter_max):
@@ -136,7 +136,7 @@ x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2))
loss_iter = []
# generator for random permutations
-gen = torch.Generator()
+gen = torch.Generator(device=device)
gen.manual_seed(42)
alpha = 0.5