summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorROUVREAU Vincent <vincent.rouvreau@inria.fr>2020-09-29 13:23:56 +0200
committerROUVREAU Vincent <vincent.rouvreau@inria.fr>2020-09-29 13:23:56 +0200
commite7b7947adf13ec1dcb8c126a4373fa29baaecb63 (patch)
treedcc3eea7fd2de15ba2eb6e6da91f115c38a8519b
parenta304049bdcfb03aa848d8049923ab796e0761b56 (diff)
Added tests for wasserstein distance with tensorflow
-rw-r--r--src/cmake/modules/GUDHI_third_party_libraries.cmake1
-rw-r--r--src/python/CMakeLists.txt8
-rw-r--r--src/python/doc/installation.rst5
-rwxr-xr-xsrc/python/test/test_wasserstein_with_tensors.py25
4 files changed, 39 insertions, 0 deletions
diff --git a/src/cmake/modules/GUDHI_third_party_libraries.cmake b/src/cmake/modules/GUDHI_third_party_libraries.cmake
index 1fbc4244..9dadac4f 100644
--- a/src/cmake/modules/GUDHI_third_party_libraries.cmake
+++ b/src/cmake/modules/GUDHI_third_party_libraries.cmake
@@ -155,6 +155,7 @@ if( PYTHONINTERP_FOUND )
find_python_module("pykeops")
find_python_module("eagerpy")
find_python_module_no_version("hnswlib")
+ find_python_module("tensorflow")
endif()
if(NOT GUDHI_PYTHON_PATH)
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 4f26481e..cc71503f 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -103,6 +103,9 @@ if(PYTHONINTERP_FOUND)
if(EAGERPY_FOUND)
add_gudhi_debug_info("EagerPy version ${EAGERPY_VERSION}")
endif()
+ if(TENSORFLOW_FOUND)
+ add_gudhi_debug_info("TensorFlow version ${TENSORFLOW_VERSION}")
+ endif()
set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DBOOST_RESULT_OF_USE_DECLTYPE', ")
set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DBOOST_ALL_NO_LIB', ")
@@ -501,6 +504,11 @@ if(PYTHONINTERP_FOUND)
endif()
add_gudhi_py_test(test_wasserstein_barycenter)
endif()
+ if(OT_FOUND)
+ if(TENSORFLOW_FOUND AND EAGERPY_FOUND)
+ add_gudhi_py_test(test_wasserstein_with_tensors)
+ endif()
+ endif()
# Representations
if(SKLEARN_FOUND AND MATPLOTLIB_FOUND)
diff --git a/src/python/doc/installation.rst b/src/python/doc/installation.rst
index 2f161d66..66efe45a 100644
--- a/src/python/doc/installation.rst
+++ b/src/python/doc/installation.rst
@@ -394,6 +394,11 @@ mathematics, science, and engineering.
:class:`~gudhi.point_cloud.knn.KNearestNeighbors` can use the Python package
`SciPy <http://scipy.org>`_ as a backend if explicitly requested.
+TensorFlow
+----------
+
+`TensorFlow <https://www.tensorflow.org>`_ is currently only used in some automatic differentiation tests.
+
Bug reports and contributions
*****************************
diff --git a/src/python/test/test_wasserstein_with_tensors.py b/src/python/test/test_wasserstein_with_tensors.py
new file mode 100755
index 00000000..8957705d
--- /dev/null
+++ b/src/python/test/test_wasserstein_with_tensors.py
@@ -0,0 +1,25 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Mathieu Carriere
+
+ Copyright (C) 2020 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.wasserstein import wasserstein_distance as pot
+import numpy as np
+
+def test_wasserstein_distance_grad_tensorflow():
+ import tensorflow as tf
+
+ with tf.GradientTape() as tape:
+ diag4 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[0., 10.]]), trainable=True))
+ diag5 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[1., 11.], [3., 4.]]), trainable=True))
+ dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True)
+ assert dist45 == 3.
+
+ grads = tape.gradient(dist45, [diag4, diag5])
+ assert np.array_equal(grads[0].values, [[-1., -1.]])
+ assert np.array_equal(grads[1].values, [[1., 1.], [-1., 1.]]) \ No newline at end of file