summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorROUVREAU Vincent <vincent.rouvreau@inria.fr>2020-10-05 11:12:44 +0200
committerROUVREAU Vincent <vincent.rouvreau@inria.fr>2020-10-05 11:12:44 +0200
commitf0beb329f5a1767e4e0a0575ef3e078bf4563a44 (patch)
tree9a6bf0f5ec838b714bdba99c84efb04e5b641cb6
parente7b7947adf13ec1dcb8c126a4373fa29baaecb63 (diff)
code review: move test_wasserstein_distance_grad from test_wasserstein_distance.py to test_wasserstein_with_tensors.py
-rw-r--r--src/python/CMakeLists.txt5
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py24
-rwxr-xr-xsrc/python/test/test_wasserstein_with_tensors.py26
3 files changed, 27 insertions, 28 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index cc71503f..c09996fe 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -499,13 +499,14 @@ if(PYTHONINTERP_FOUND)
# Wasserstein
if(OT_FOUND AND PYBIND11_FOUND)
- if(TORCH_FOUND AND EAGERPY_FOUND)
+ # EagerPy dependency because of enable_autodiff=True
+ if(EAGERPY_FOUND)
add_gudhi_py_test(test_wasserstein_distance)
endif()
add_gudhi_py_test(test_wasserstein_barycenter)
endif()
if(OT_FOUND)
- if(TENSORFLOW_FOUND AND EAGERPY_FOUND)
+ if(TORCH_FOUND AND TENSORFLOW_FOUND AND EAGERPY_FOUND)
add_gudhi_py_test(test_wasserstein_with_tensors)
endif()
endif()
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 90d26809..e3b521d6 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -97,27 +97,3 @@ def test_wasserstein_distance_pot():
def test_wasserstein_distance_hera():
_basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False)
_basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False)
-
-def test_wasserstein_distance_grad():
- import torch
-
- diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True)
- diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
- diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
- assert diag1.grad is None and diag2.grad is None and diag3.grad is None
- dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True)
- dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True)
- dist12.backward()
- dist30.backward()
- assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any()
- diag4 = torch.tensor([[0., 10.]], requires_grad=True)
- diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True)
- dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True)
- assert dist45 == 3.
- dist45.backward()
- assert np.array_equal(diag4.grad, [[-1., -1.]])
- assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]])
- diag6 = torch.tensor([[5., 10.]], requires_grad=True)
- pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward()
- # https://github.com/jonasrauber/eagerpy/issues/6
- # assert np.array_equal(diag6.grad, [[0., 0.]])
diff --git a/src/python/test/test_wasserstein_with_tensors.py b/src/python/test/test_wasserstein_with_tensors.py
index 8957705d..e3f1411a 100755
--- a/src/python/test/test_wasserstein_with_tensors.py
+++ b/src/python/test/test_wasserstein_with_tensors.py
@@ -10,10 +10,32 @@
from gudhi.wasserstein import wasserstein_distance as pot
import numpy as np
+import torch
+import tensorflow as tf
-def test_wasserstein_distance_grad_tensorflow():
- import tensorflow as tf
+def test_wasserstein_distance_grad():
+ diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True)
+ diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
+ diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
+ assert diag1.grad is None and diag2.grad is None and diag3.grad is None
+ dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True)
+ dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True)
+ dist12.backward()
+ dist30.backward()
+ assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any()
+ diag4 = torch.tensor([[0., 10.]], requires_grad=True)
+ diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True)
+ dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True)
+ assert dist45 == 3.
+ dist45.backward()
+ assert np.array_equal(diag4.grad, [[-1., -1.]])
+ assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]])
+ diag6 = torch.tensor([[5., 10.]], requires_grad=True)
+ pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward()
+ # https://github.com/jonasrauber/eagerpy/issues/6
+ # assert np.array_equal(diag6.grad, [[0., 0.]])
+def test_wasserstein_distance_grad_tensorflow():
with tf.GradientTape() as tape:
diag4 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[0., 10.]]), trainable=True))
diag5 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[1., 11.], [3., 4.]]), trainable=True))