summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_distance.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-22 16:29:26 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-22 16:29:26 +0200
commit51f7b5bb15f351d08af4c26bd1ffdfe979199976 (patch)
tree443aae53f7172b140e4644378a32a6212e4f1939 /src/python/test/test_wasserstein_distance.py
parentda2a7a68f8f57495080af37cf981f64228d165a2 (diff)
Test value of computed gradient
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py19
1 files changed, 15 insertions, 4 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 6bfcb2ee..90d26809 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -105,8 +105,19 @@ def test_wasserstein_distance_grad():
diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
assert diag1.grad is None and diag2.grad is None and diag3.grad is None
- dist1 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True)
- dist2 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True)
- dist1.backward()
- dist2.backward()
+ dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True)
+ dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True)
+ dist12.backward()
+ dist30.backward()
assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any()
+ diag4 = torch.tensor([[0., 10.]], requires_grad=True)
+ diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True)
+ dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True)
+ assert dist45 == 3.
+ dist45.backward()
+ assert np.array_equal(diag4.grad, [[-1., -1.]])
+ assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]])
+ diag6 = torch.tensor([[5., 10.]], requires_grad=True)
+ pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward()
+ # https://github.com/jonasrauber/eagerpy/issues/6
+ # assert np.array_equal(diag6.grad, [[0., 0.]])