summaryrefslogtreecommitdiff
path: root/ot/gromov/_semirelaxed.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/gromov/_semirelaxed.py')
-rw-r--r--ot/gromov/_semirelaxed.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py
index cb2bf28..94dc975 100644
--- a/ot/gromov/_semirelaxed.py
+++ b/ot/gromov/_semirelaxed.py
@@ -467,8 +467,15 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss',
if loss_fun == 'square_loss':
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
- srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
- (alpha * gC1, alpha * gC2, (1 - alpha) * T))
+ if isinstance(alpha, int) or isinstance(alpha, float):
+ srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
+ (alpha * gC1, alpha * gC2, (1 - alpha) * T))
+ else:
+ lin_term = nx.sum(T * M)
+ srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha
+ srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha),
+ (alpha * gC1, alpha * gC2, (1 - alpha) * T,
+ srgw_term - lin_term))
if log:
return srfgw_dist, log_fgw