diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2023-04-24 17:54:03 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-24 17:54:03 +0200 |
commit | 03ca4ef659a037e400975e3b2116b637a2d94265 (patch) | |
tree | 2fff6add4b430a9bb97cf594786777c7e48ea5a5 /ot/gromov | |
parent | 25d72db09ed281c13b97aa8a68d82a4ed5ba7bf0 (diff) |
[MRG] make alpha parameter in FGW diferentiable (#463)
* make alpha diferentiable
* update release file
* debug tensorflow to_numpy
Diffstat (limited to 'ot/gromov')
-rw-r--r-- | ot/gromov/_gw.py | 20 |
1 files changed, 15 insertions, 5 deletions
diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index c6e4076..bc4719d 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -370,7 +370,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric= Information and Inference: A Journal of the IMA, 8(4), 757-787. """ p, q = list_to_array(p, q) - p0, q0, C10, C20, M0 = p, q, C1, C2, M + p0, q0, C10, C20, M0, alpha0 = p, q, C1, C2, M, alpha if G0 is None: nx = get_backend(p0, q0, C10, C20, M0) else: @@ -382,6 +382,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric= C1 = nx.to_numpy(C10) C2 = nx.to_numpy(C20) M = nx.to_numpy(M0) + alpha = nx.to_numpy(alpha0) if symmetric is None: symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10) @@ -535,10 +536,19 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric if loss_fun == 'square_loss': gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) - fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M), - (log_fgw['u'] - nx.mean(log_fgw['u']), - log_fgw['v'] - nx.mean(log_fgw['v']), - alpha * gC1, alpha * gC2, (1 - alpha) * T)) + if isinstance(alpha, int) or isinstance(alpha, float): + fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M), + (log_fgw['u'] - nx.mean(log_fgw['u']), + log_fgw['v'] - nx.mean(log_fgw['v']), + alpha * gC1, alpha * gC2, (1 - alpha) * T)) + else: + lin_term = nx.sum(T * M) + gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha + fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha), + (log_fgw['u'] - nx.mean(log_fgw['u']), + log_fgw['v'] - nx.mean(log_fgw['v']), + alpha * gC1, alpha * gC2, (1 - alpha) * T, + gw_term - lin_term)) if log: return fgw_dist, log_fgw |