From f0dab2f684f4fc768fd50e0b70918e075dcdd0f3 Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Mon, 12 Jun 2023 10:12:01 +0200 Subject: [FEAT] Alpha differentiability in semirelaxed_gromov_wasserstein2 (#483) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * alpha differentiable * autopep and update gradient test * debug test gradient --------- Co-authored-by: RĂ©mi Flamary --- ot/gromov/_semirelaxed.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'ot/gromov/_semirelaxed.py') diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index cb2bf28..94dc975 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -467,8 +467,15 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', if loss_fun == 'square_loss': gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) - srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M), - (alpha * gC1, alpha * gC2, (1 - alpha) * T)) + if isinstance(alpha, int) or isinstance(alpha, float): + srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M), + (alpha * gC1, alpha * gC2, (1 - alpha) * T)) + else: + lin_term = nx.sum(T * M) + srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha + srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha), + (alpha * gC1, alpha * gC2, (1 - alpha) * T, + srgw_term - lin_term)) if log: return srfgw_dist, log_fgw -- cgit v1.2.3