summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com>2023-06-12 10:12:01 +0200
committerGitHub <noreply@github.com>2023-06-12 10:12:01 +0200
commitf0dab2f684f4fc768fd50e0b70918e075dcdd0f3 (patch)
tree5fd86a99d40a5244ecb0d00e38b6d52dc5d2ef0e
parentf76dd53c2bdac86d2c4ed51e0be3d0169621fda5 (diff)
[FEAT] Alpha differentiability in semirelaxed_gromov_wasserstein2 (#483)
* alpha differentiable * autopep and update gradient test * debug test gradient --------- Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
-rw-r--r--RELEASES.md1
-rw-r--r--ot/gromov/_semirelaxed.py11
-rw-r--r--test/test_gromov.py17
3 files changed, 27 insertions, 2 deletions
diff --git a/RELEASES.md b/RELEASES.md
index cd0bcde..61ad2ca 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -3,6 +3,7 @@
## 0.9.1dev
#### New features
+- Make alpha parameter in semi-relaxed Fused Gromov Wasserstein differentiable (PR #483)
- Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
- Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459)
- Add tests on GPU for master branch and approved PR (PR #473)
diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py
index cb2bf28..94dc975 100644
--- a/ot/gromov/_semirelaxed.py
+++ b/ot/gromov/_semirelaxed.py
@@ -467,8 +467,15 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss',
if loss_fun == 'square_loss':
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
- srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
- (alpha * gC1, alpha * gC2, (1 - alpha) * T))
+ if isinstance(alpha, int) or isinstance(alpha, float):
+ srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
+ (alpha * gC1, alpha * gC2, (1 - alpha) * T))
+ else:
+ lin_term = nx.sum(T * M)
+ srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha
+ srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha),
+ (alpha * gC1, alpha * gC2, (1 - alpha) * T,
+ srgw_term - lin_term))
if log:
return srfgw_dist, log_fgw
diff --git a/test/test_gromov.py b/test/test_gromov.py
index f70f410..1beb818 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -1752,6 +1752,23 @@ def test_semirelaxed_fgw2_gradients():
assert C12.shape == C12.grad.shape
assert M1.shape == M1.grad.shape
+ # full gradients with alpha
+ p1 = torch.tensor(p, requires_grad=False, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+ M1 = torch.tensor(M, requires_grad=True, device=device)
+ alpha = torch.tensor(0.5, requires_grad=True, device=device)
+
+ val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, alpha=alpha)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert p1.grad is None
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert alpha.shape == alpha.grad.shape
+
def test_srfgw_helper_backend(nx):
n_samples = 20 # nb samples