diff options
author | Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> | 2022-03-02 11:35:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-03-02 11:35:47 +0100 |
commit | 9412f0ad1c0003e659b7d779bf8b6728e0e5e60f (patch) | |
tree | 452ac5ace351a295e8e5c36224b4f69b7153fed6 | |
parent | 17814726200b4010afbf52701e8bcb132d678502 (diff) |
[MRG] Gromov_Wasserstein2 not performing backward properly on GPU (#352)
* Resolves gromov wasserstein backward bug
* release file updated
-rw-r--r-- | RELEASES.md | 3 | ||||
-rw-r--r-- | ot/gromov.py | 12 | ||||
-rw-r--r-- | test/test_gromov.py | 60 |
3 files changed, 46 insertions, 29 deletions
diff --git a/RELEASES.md b/RELEASES.md index c1068f3..18562e7 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -18,6 +18,9 @@ - Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337, PR #338) - Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349) +- Fix bug where gromov_wasserstein2 does not perform backpropagation with CUDA + tensors (Issue #351, PR #352) + ## 0.8.1.0 *December 2021* diff --git a/ot/gromov.py b/ot/gromov.py index f5a1f91..c5a82d1 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -546,8 +546,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= gw = log_gw['gw_dist']
if loss_fun == 'square_loss':
- gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
- gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
+ gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
+ gC1 = nx.from_numpy(gC1, type_as=C10)
+ gC2 = nx.from_numpy(gC2, type_as=C10)
gw = nx.set_gradients(gw, (p0, q0, C10, C20),
(log_gw['u'], log_gw['v'], gC1, gC2))
@@ -786,8 +788,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 log_fgw['T'] = T0
if loss_fun == 'square_loss':
- gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
- gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
+ gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
+ gC1 = nx.from_numpy(gC1, type_as=C10)
+ gC2 = nx.from_numpy(gC2, type_as=C10)
fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
(log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))
diff --git a/test/test_gromov.py b/test/test_gromov.py index 329f99c..0dcf2da 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -181,19 +181,24 @@ def test_gromov2_gradients(): if torch:
- p1 = torch.tensor(p, requires_grad=True)
- q1 = torch.tensor(q, requires_grad=True)
- C11 = torch.tensor(C1, requires_grad=True)
- C12 = torch.tensor(C2, requires_grad=True)
+ devices = [torch.device("cpu")]
+ if torch.cuda.is_available():
+ devices.append(torch.device("cuda"))
+ for device in devices:
+ p1 = torch.tensor(p, requires_grad=True, device=device)
+ q1 = torch.tensor(q, requires_grad=True, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
- val = ot.gromov_wasserstein2(C11, C12, p1, q1)
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1)
- val.backward()
+ val.backward()
- assert q1.shape == q1.grad.shape
- assert p1.shape == p1.grad.shape
- assert C11.shape == C11.grad.shape
- assert C12.shape == C12.grad.shape
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
@pytest.skip_backend("jax", reason="test very slow with jax backend")
@@ -636,21 +641,26 @@ def test_fgw2_gradients(): if torch:
- p1 = torch.tensor(p, requires_grad=True)
- q1 = torch.tensor(q, requires_grad=True)
- C11 = torch.tensor(C1, requires_grad=True)
- C12 = torch.tensor(C2, requires_grad=True)
- M1 = torch.tensor(M, requires_grad=True)
-
- val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
-
- val.backward()
-
- assert q1.shape == q1.grad.shape
- assert p1.shape == p1.grad.shape
- assert C11.shape == C11.grad.shape
- assert C12.shape == C12.grad.shape
- assert M1.shape == M1.grad.shape
+ devices = [torch.device("cpu")]
+ if torch.cuda.is_available():
+ devices.append(torch.device("cuda"))
+ for device in devices:
+ p1 = torch.tensor(p, requires_grad=True, device=device)
+ q1 = torch.tensor(q, requires_grad=True, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+ M1 = torch.tensor(M, requires_grad=True, device=device)
+
+ val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert M1.shape == M1.grad.shape
def test_fgw_barycenter(nx):
|