From 1086b8cad7c1ea2a02742dfc44aef036a674f5d3 Mon Sep 17 00:00:00 2001
From: Marc Glisse
Date: Sun, 19 Apr 2020 12:17:42 +0200
Subject: Test gradient

---
 src/python/test/test_wasserstein_distance.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 5bec5bd3..c6d6b346 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -90,3 +90,16 @@ def test_wasserstein_distance_pot():
 def test_wasserstein_distance_hera():
     _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False)
     _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False)
+
+def test_wasserstein_distance_grad():
+    import torch
+
+    diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True)
+    diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
+    diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
+    assert diag1.grad is None and diag2.grad is None and diag3.grad is None
+    dist1 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True)
+    dist2 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True)
+    dist1.backward()
+    dist2.backward()
+    assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any()
--
cgit v1.2.3
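
Note: the name pot used in the new test is presumably the POT-backed gudhi.wasserstein.wasserstein_distance, imported near the top of test_wasserstein_distance.py (outside this hunk). As a rough illustration only, not part of the patch, the behaviour the test checks can be sketched standalone as follows, assuming a GUDHI build with POT and a working PyTorch install:

    # Minimal sketch of the autodiff behaviour exercised by the new test.
    # With enable_autodiff=True, the returned distance takes part in PyTorch's
    # autograd graph, so backward() fills the .grad fields of the inputs.
    import torch
    from gudhi.wasserstein import wasserstein_distance

    # Two small persistence diagrams as differentiable tensors (birth, death).
    diag_a = torch.tensor([[2.7, 3.7], [9.6, 14.0]], requires_grad=True)
    diag_b = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)

    # Differentiable 2-Wasserstein distance between the two diagrams.
    dist = wasserstein_distance(diag_a, diag_b, order=2, internal_p=2,
                                enable_autodiff=True)
    dist.backward()

    # Each input diagram now carries a gradient of the distance w.r.t. its points.
    print(diag_a.grad)
    print(diag_b.grad)

The test additionally calls backward() on a distance against an empty diagram (torch.tensor([])), which covers the case where every point of diag3 is matched to the diagonal, and then asserts that none of the resulting gradients contain NaN.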