diff options
author | Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> | 2022-03-02 11:35:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-03-02 11:35:47 +0100 |
commit | 9412f0ad1c0003e659b7d779bf8b6728e0e5e60f (patch) | |
tree | 452ac5ace351a295e8e5c36224b4f69b7153fed6 /test | |
parent | 17814726200b4010afbf52701e8bcb132d678502 (diff) |
[MRG] Gromov_Wasserstein2 not performing backward properly on GPU (#352)
* Resolves gromov wasserstein backward bug
* release file updated
Diffstat (limited to 'test')
-rw-r--r-- | test/test_gromov.py | 60 |
1 files changed, 35 insertions, 25 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py index 329f99c..0dcf2da 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -181,19 +181,24 @@ def test_gromov2_gradients(): if torch:
- p1 = torch.tensor(p, requires_grad=True)
- q1 = torch.tensor(q, requires_grad=True)
- C11 = torch.tensor(C1, requires_grad=True)
- C12 = torch.tensor(C2, requires_grad=True)
+ devices = [torch.device("cpu")]
+ if torch.cuda.is_available():
+ devices.append(torch.device("cuda"))
+ for device in devices:
+ p1 = torch.tensor(p, requires_grad=True, device=device)
+ q1 = torch.tensor(q, requires_grad=True, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
- val = ot.gromov_wasserstein2(C11, C12, p1, q1)
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1)
- val.backward()
+ val.backward()
- assert q1.shape == q1.grad.shape
- assert p1.shape == p1.grad.shape
- assert C11.shape == C11.grad.shape
- assert C12.shape == C12.grad.shape
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
@pytest.skip_backend("jax", reason="test very slow with jax backend")
@@ -636,21 +641,26 @@ def test_fgw2_gradients(): if torch:
- p1 = torch.tensor(p, requires_grad=True)
- q1 = torch.tensor(q, requires_grad=True)
- C11 = torch.tensor(C1, requires_grad=True)
- C12 = torch.tensor(C2, requires_grad=True)
- M1 = torch.tensor(M, requires_grad=True)
-
- val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
-
- val.backward()
-
- assert q1.shape == q1.grad.shape
- assert p1.shape == p1.grad.shape
- assert C11.shape == C11.grad.shape
- assert C12.shape == C12.grad.shape
- assert M1.shape == M1.grad.shape
+ devices = [torch.device("cpu")]
+ if torch.cuda.is_available():
+ devices.append(torch.device("cuda"))
+ for device in devices:
+ p1 = torch.tensor(p, requires_grad=True, device=device)
+ q1 = torch.tensor(q, requires_grad=True, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+ M1 = torch.tensor(M, requires_grad=True, device=device)
+
+ val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert M1.shape == M1.grad.shape
def test_fgw_barycenter(nx):
|