From 9412f0ad1c0003e659b7d779bf8b6728e0e5e60f Mon Sep 17 00:00:00 2001 From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> Date: Wed, 2 Mar 2022 11:35:47 +0100 Subject: [MRG] Gromov_Wasserstein2 not performing backward properly on GPU (#352) * Resolves gromov wasserstein backward bug * release file updated --- ot/gromov.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'ot/gromov.py') diff --git a/ot/gromov.py b/ot/gromov.py index f5a1f91..c5a82d1 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -546,8 +546,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= gw = log_gw['gw_dist'] if loss_fun == 'square_loss': - gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) - gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T) + gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T) + gC1 = nx.from_numpy(gC1, type_as=C10) + gC2 = nx.from_numpy(gC2, type_as=C10) gw = nx.set_gradients(gw, (p0, q0, C10, C20), (log_gw['u'], log_gw['v'], gC1, gC2)) @@ -786,8 +788,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 log_fgw['T'] = T0 if loss_fun == 'square_loss': - gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) - gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T) + gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T) + gC1 = nx.from_numpy(gC1, type_as=C10) + gC2 = nx.from_numpy(gC2, type_as=C10) fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0), (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0)) -- cgit v1.2.3