summaryrefslogtreecommitdiff
path: root/test/test_gromov.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r--test/test_gromov.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index cfccce7..80b6df4 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -214,6 +214,7 @@ def test_gromov2_gradients():
C11 = torch.tensor(C1, requires_grad=True, device=device)
C12 = torch.tensor(C2, requires_grad=True, device=device)
+ # Test with exact line-search
val = ot.gromov_wasserstein2(C11, C12, p1, q1)
val.backward()
@@ -224,6 +225,21 @@ def test_gromov2_gradients():
assert C11.shape == C11.grad.shape
assert C12.shape == C12.grad.shape
+ # Test with armijo line-search
+ q1.grad = None
+ p1.grad = None
+ C11.grad = None
+ C12.grad = None
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1, armijo=True)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+
def test_gw_helper_backend(nx):
n_samples = 20 # nb samples