summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_distance.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-19 12:17:42 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-19 12:17:42 +0200
commit1086b8cad7c1ea2a02742dfc44aef036a674f5d3 (patch)
treeab2f94fd642048ec1b83a53c9b6e5fbe2b0c43b5 /src/python/test/test_wasserstein_distance.py
parentb2a9ba18ce33778abdd9f5032af4bfff04e8bbd2 (diff)
Test gradient
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 5bec5bd3..c6d6b346 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -90,3 +90,16 @@ def test_wasserstein_distance_pot():
def test_wasserstein_distance_hera():
_basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False)
_basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False)
+
+def test_wasserstein_distance_grad():
+ import torch
+
+ diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True)
+ diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
+ diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
+ assert diag1.grad is None and diag2.grad is None and diag3.grad is None
+ dist1 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True)
+ dist2 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True)
+ dist1.backward()
+ dist2.backward()
+ assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any()