summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_distance.py
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2021-04-12 15:54:26 +0200
committertlacombe <lacombe1993@gmail.com>2021-04-12 15:54:26 +0200
commit2a11e3651c2d66df8371a9aa1d23dff69ffbc31c (patch)
tree44c8c150c4cf971c73955a0d5c41ccad161d3001 /src/python/test/test_wasserstein_distance.py
parent777522b82bde16b55f15c21471bad06038849fd1 (diff)
removed test_wasserstein_distance_grad to be consistent with master
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py23
1 files changed, 0 insertions, 23 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 12bf71df..14d5c2ca 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -159,26 +159,3 @@ def test_wasserstein_distance_hera():
_basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False)
_basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False)
-def test_wasserstein_distance_grad():
- import torch
-
- diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True)
- diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
- diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
- assert diag1.grad is None and diag2.grad is None and diag3.grad is None
- dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True, keep_essential_parts=False)
- dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True, keep_essential_parts=False)
- dist12.backward()
- dist30.backward()
- assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any()
- diag4 = torch.tensor([[0., 10.]], requires_grad=True)
- diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True)
- dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True, keep_essential_parts=False)
- assert dist45 == 3.
- dist45.backward()
- assert np.array_equal(diag4.grad, [[-1., -1.]])
- assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]])
- diag6 = torch.tensor([[5., 10.]], requires_grad=True)
- pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True, keep_essential_parts=False).backward()
- # https://github.com/jonasrauber/eagerpy/issues/6
- # assert np.array_equal(diag6.grad, [[0., 0.]])