summaryrefslogtreecommitdiff
path: root/ot/gromov/_gw.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/gromov/_gw.py')
-rw-r--r--ot/gromov/_gw.py20
1 files changed, 15 insertions, 5 deletions
diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py
index c6e4076..bc4719d 100644
--- a/ot/gromov/_gw.py
+++ b/ot/gromov/_gw.py
@@ -370,7 +370,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
Information and Inference: A Journal of the IMA, 8(4), 757-787.
"""
p, q = list_to_array(p, q)
- p0, q0, C10, C20, M0 = p, q, C1, C2, M
+ p0, q0, C10, C20, M0, alpha0 = p, q, C1, C2, M, alpha
if G0 is None:
nx = get_backend(p0, q0, C10, C20, M0)
else:
@@ -382,6 +382,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
C1 = nx.to_numpy(C10)
C2 = nx.to_numpy(C20)
M = nx.to_numpy(M0)
+ alpha = nx.to_numpy(alpha0)
if symmetric is None:
symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10)
@@ -535,10 +536,19 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric
if loss_fun == 'square_loss':
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
- fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
- (log_fgw['u'] - nx.mean(log_fgw['u']),
- log_fgw['v'] - nx.mean(log_fgw['v']),
- alpha * gC1, alpha * gC2, (1 - alpha) * T))
+ if isinstance(alpha, int) or isinstance(alpha, float):
+ fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
+ (log_fgw['u'] - nx.mean(log_fgw['u']),
+ log_fgw['v'] - nx.mean(log_fgw['v']),
+ alpha * gC1, alpha * gC2, (1 - alpha) * T))
+ else:
+ lin_term = nx.sum(T * M)
+ gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha
+ fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
+ (log_fgw['u'] - nx.mean(log_fgw['u']),
+ log_fgw['v'] - nx.mean(log_fgw['v']),
+ alpha * gC1, alpha * gC2, (1 - alpha) * T,
+ gw_term - lin_term))
if log:
return fgw_dist, log_fgw