author    Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>    2022-03-02 11:35:47 +0100
committer GitHub <noreply@github.com>    2022-03-02 11:35:47 +0100
commit    9412f0ad1c0003e659b7d779bf8b6728e0e5e60f (patch)
tree      452ac5ace351a295e8e5c36224b4f69b7153fed6 /test
parent    17814726200b4010afbf52701e8bcb132d678502 (diff)
[MRG] Gromov_Wasserstein2 not performing backward properly on GPU (#352)

* Resolves Gromov-Wasserstein backward bug
* release file updated
Diffstat (limited to 'test')

-rw-r--r--    test/test_gromov.py    60
1 file changed, 35 insertions(+), 25 deletions(-)
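A minimal sketch (not part of the commit) of the behaviour the updated tests below exercise: ot.gromov_wasserstein2 should return a loss on the same device as its torch inputs and support backward() there. The point-cloud sizes, random data, and variable names are illustrative assumptions, not taken from the test file.

    # Minimal sketch, assuming POT's torch backend; data and sizes are illustrative.
    import numpy as np
    import torch
    import ot

    rng = np.random.RandomState(42)
    xs = rng.randn(20, 2)          # source samples (assumed size)
    xt = rng.randn(30, 2)          # target samples (assumed size)
    C1 = ot.dist(xs, xs)           # intra-domain cost matrices
    C2 = ot.dist(xt, xt)
    p = ot.unif(20)                # uniform marginal weights
    q = ot.unif(30)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    p1 = torch.tensor(p, requires_grad=True, device=device)
    q1 = torch.tensor(q, requires_grad=True, device=device)
    C11 = torch.tensor(C1, requires_grad=True, device=device)
    C12 = torch.tensor(C2, requires_grad=True, device=device)

    val = ot.gromov_wasserstein2(C11, C12, p1, q1)
    val.backward()

    assert val.device == p1.device       # loss stays on the input device
    assert C11.grad.shape == C11.shape   # gradients are populated

On a CUDA device this is the case the commit fixes; on CPU it should already have passed.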
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 329f99c..0dcf2da 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -181,19 +181,24 @@ def test_gromov2_gradients():
     if torch:
-        p1 = torch.tensor(p, requires_grad=True)
-        q1 = torch.tensor(q, requires_grad=True)
-        C11 = torch.tensor(C1, requires_grad=True)
-        C12 = torch.tensor(C2, requires_grad=True)
+        devices = [torch.device("cpu")]
+        if torch.cuda.is_available():
+            devices.append(torch.device("cuda"))
+        for device in devices:
+            p1 = torch.tensor(p, requires_grad=True, device=device)
+            q1 = torch.tensor(q, requires_grad=True, device=device)
+            C11 = torch.tensor(C1, requires_grad=True, device=device)
+            C12 = torch.tensor(C2, requires_grad=True, device=device)
-        val = ot.gromov_wasserstein2(C11, C12, p1, q1)
+            val = ot.gromov_wasserstein2(C11, C12, p1, q1)
-        val.backward()
+            val.backward()
-        assert q1.shape == q1.grad.shape
-        assert p1.shape == p1.grad.shape
-        assert C11.shape == C11.grad.shape
-        assert C12.shape == C12.grad.shape
+            assert val.device == p1.device
+            assert q1.shape == q1.grad.shape
+            assert p1.shape == p1.grad.shape
+            assert C11.shape == C11.grad.shape
+            assert C12.shape == C12.grad.shape
 @pytest.skip_backend("jax", reason="test very slow with jax backend")
@@ -636,21 +641,26 @@ def test_fgw2_gradients():
     if torch:
-        p1 = torch.tensor(p, requires_grad=True)
-        q1 = torch.tensor(q, requires_grad=True)
-        C11 = torch.tensor(C1, requires_grad=True)
-        C12 = torch.tensor(C2, requires_grad=True)
-        M1 = torch.tensor(M, requires_grad=True)
-
-        val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
-
-        val.backward()
-
-        assert q1.shape == q1.grad.shape
-        assert p1.shape == p1.grad.shape
-        assert C11.shape == C11.grad.shape
-        assert C12.shape == C12.grad.shape
-        assert M1.shape == M1.grad.shape
+        devices = [torch.device("cpu")]
+        if torch.cuda.is_available():
+            devices.append(torch.device("cuda"))
+        for device in devices:
+            p1 = torch.tensor(p, requires_grad=True, device=device)
+            q1 = torch.tensor(q, requires_grad=True, device=device)
+            C11 = torch.tensor(C1, requires_grad=True, device=device)
+            C12 = torch.tensor(C2, requires_grad=True, device=device)
+            M1 = torch.tensor(M, requires_grad=True, device=device)
+
+            val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
+
+            val.backward()
+
+            assert val.device == p1.device
+            assert q1.shape == q1.grad.shape
+            assert p1.shape == p1.grad.shape
+            assert C11.shape == C11.grad.shape
+            assert C12.shape == C12.grad.shape
+            assert M1.shape == M1.grad.shape
 def test_fgw_barycenter(nx):