summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2023-04-24 17:54:03 +0200
committerGitHub <noreply@github.com>2023-04-24 17:54:03 +0200
commit03ca4ef659a037e400975e3b2116b637a2d94265 (patch)
tree2fff6add4b430a9bb97cf594786777c7e48ea5a5 /test
parent25d72db09ed281c13b97aa8a68d82a4ed5ba7bf0 (diff)
[MRG] make alpha parameter in FGW differentiable (#463)
* make alpha differentiable * update release file * debug tensorflow to_numpy
Diffstat (limited to 'test')
-rw-r--r--test/test_gromov.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 80b6df4..f70f410 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -209,6 +209,8 @@ def test_gromov2_gradients():
if torch.cuda.is_available():
devices.append(torch.device("cuda"))
for device in devices:
+
+ # classical gradients
p1 = torch.tensor(p, requires_grad=True, device=device)
q1 = torch.tensor(q, requires_grad=True, device=device)
C11 = torch.tensor(C1, requires_grad=True, device=device)
@@ -226,6 +228,12 @@ def test_gromov2_gradients():
assert C12.shape == C12.grad.shape
# Test with armijo line-search
+ # classical gradients
+ p1 = torch.tensor(p, requires_grad=True, device=device)
+ q1 = torch.tensor(q, requires_grad=True, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+
q1.grad = None
p1.grad = None
C11.grad = None
@@ -830,6 +838,25 @@ def test_fgw2_gradients():
assert C12.shape == C12.grad.shape
assert M1.shape == M1.grad.shape
+ # full gradients with alpha
+ p1 = torch.tensor(p, requires_grad=True, device=device)
+ q1 = torch.tensor(q, requires_grad=True, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+ M1 = torch.tensor(M, requires_grad=True, device=device)
+ alpha = torch.tensor(0.5, requires_grad=True, device=device)
+
+ val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1, alpha=alpha)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert alpha.shape == alpha.grad.shape
+
def test_fgw_helper_backend(nx):
n_samples = 20 # nb samples