summaryrefslogtreecommitdiff
path: root/test/test_gromov.py
diff options
context:
space:
mode:
authorSoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com>2023-06-12 10:12:01 +0200
committerGitHub <noreply@github.com>2023-06-12 10:12:01 +0200
commitf0dab2f684f4fc768fd50e0b70918e075dcdd0f3 (patch)
tree5fd86a99d40a5244ecb0d00e38b6d52dc5d2ef0e /test/test_gromov.py
parentf76dd53c2bdac86d2c4ed51e0be3d0169621fda5 (diff)
[FEAT] Alpha differentiability in semirelaxed_gromov_wasserstein2 (#483)
* alpha differentiable * autopep and update gradient test * debug test gradient --------- Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r--test/test_gromov.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index f70f410..1beb818 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -1752,6 +1752,23 @@ def test_semirelaxed_fgw2_gradients():
assert C12.shape == C12.grad.shape
assert M1.shape == M1.grad.shape
+ # full gradients with alpha
+ p1 = torch.tensor(p, requires_grad=False, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+ M1 = torch.tensor(M, requires_grad=True, device=device)
+ alpha = torch.tensor(0.5, requires_grad=True, device=device)
+
+ val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, alpha=alpha)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert p1.grad is None
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert alpha.shape == alpha.grad.shape
+
def test_srfgw_helper_backend(nx):
n_samples = 20 # nb samples