From 9412f0ad1c0003e659b7d779bf8b6728e0e5e60f Mon Sep 17 00:00:00 2001 From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> Date: Wed, 2 Mar 2022 11:35:47 +0100 Subject: [MRG] Gromov_Wasserstein2 not performing backward properly on GPU (#352) * Resolves gromov wasserstein backward bug * release file updated --- test/test_gromov.py | 60 +++++++++++++++++++++++++++++++---------------------- 1 file changed, 35 insertions(+), 25 deletions(-) (limited to 'test') diff --git a/test/test_gromov.py b/test/test_gromov.py index 329f99c..0dcf2da 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -181,19 +181,24 @@ def test_gromov2_gradients(): if torch: - p1 = torch.tensor(p, requires_grad=True) - q1 = torch.tensor(q, requires_grad=True) - C11 = torch.tensor(C1, requires_grad=True) - C12 = torch.tensor(C2, requires_grad=True) + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + p1 = torch.tensor(p, requires_grad=True, device=device) + q1 = torch.tensor(q, requires_grad=True, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) - val = ot.gromov_wasserstein2(C11, C12, p1, q1) + val = ot.gromov_wasserstein2(C11, C12, p1, q1) - val.backward() + val.backward() - assert q1.shape == q1.grad.shape - assert p1.shape == p1.grad.shape - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape + assert val.device == p1.device + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape @pytest.skip_backend("jax", reason="test very slow with jax backend") @@ -636,21 +641,26 @@ def test_fgw2_gradients(): if torch: - p1 = torch.tensor(p, requires_grad=True) - q1 = torch.tensor(q, requires_grad=True) - C11 = torch.tensor(C1, requires_grad=True) - C12 = torch.tensor(C2, requires_grad=True) - M1 = torch.tensor(M, requires_grad=True) - - val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1) - - val.backward() - - assert q1.shape == q1.grad.shape - assert p1.shape == p1.grad.shape - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape - assert M1.shape == M1.grad.shape + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + p1 = torch.tensor(p, requires_grad=True, device=device) + q1 = torch.tensor(q, requires_grad=True, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + M1 = torch.tensor(M, requires_grad=True, device=device) + + val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1) + + val.backward() + + assert val.device == p1.device + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert M1.shape == M1.grad.shape def test_fgw_barycenter(nx): -- cgit v1.2.3