summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_distance.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py38
1 files changed, 34 insertions, 4 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 1a4acc1d..90d26809 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -80,14 +80,44 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
-def hera_wrap(delta):
+def hera_wrap(**extra):
def fun(*kargs,**kwargs):
- return hera(*kargs,**kwargs,delta=delta)
+ return hera(*kargs,**kwargs,**extra)
+ return fun
+
+def pot_wrap(**extra):
+ def fun(*kargs,**kwargs):
+ return pot(*kargs,**kwargs,**extra)
return fun
def test_wasserstein_distance_pot():
_basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True)
+ _basic_wasserstein(pot_wrap(enable_autodiff=True), 1e-15, test_infinity=False, test_matching=False)
def test_wasserstein_distance_hera():
- _basic_wasserstein(hera_wrap(1e-12), 1e-12, test_matching=False)
- _basic_wasserstein(hera_wrap(.1), .1, test_matching=False)
+ _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False)
+ _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False)
+
+def test_wasserstein_distance_grad():
+ import torch
+
+ diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True)
+ diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
+ diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
+ assert diag1.grad is None and diag2.grad is None and diag3.grad is None
+ dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True)
+ dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True)
+ dist12.backward()
+ dist30.backward()
+ assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any()
+ diag4 = torch.tensor([[0., 10.]], requires_grad=True)
+ diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True)
+ dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True)
+ assert dist45 == 3.
+ dist45.backward()
+ assert np.array_equal(diag4.grad, [[-1., -1.]])
+ assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]])
+ diag6 = torch.tensor([[5., 10.]], requires_grad=True)
+ pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward()
+ # https://github.com/jonasrauber/eagerpy/issues/6
+ # assert np.array_equal(diag6.grad, [[0., 0.]])