From 51f7b5bb15f351d08af4c26bd1ffdfe979199976 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Wed, 22 Apr 2020 16:29:26 +0200 Subject: Test value of computed gradient --- src/python/test/test_wasserstein_distance.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 6bfcb2ee..90d26809 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -105,8 +105,19 @@ def test_wasserstein_distance_grad(): diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) assert diag1.grad is None and diag2.grad is None and diag3.grad is None - dist1 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True) - dist2 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True) - dist1.backward() - dist2.backward() + dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True) + dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True) + dist12.backward() + dist30.backward() assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any() + diag4 = torch.tensor([[0., 10.]], requires_grad=True) + diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True) + dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True) + assert dist45 == 3. + dist45.backward() + assert np.array_equal(diag4.grad, [[-1., -1.]]) + assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]]) + diag6 = torch.tensor([[5., 10.]], requires_grad=True) + pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward() + # https://github.com/jonasrauber/eagerpy/issues/6 + # assert np.array_equal(diag6.grad, [[0., 0.]]) -- cgit v1.2.3