summaryrefslogtreecommitdiff
path: root/src/python
diff options
context:
space:
mode:
authorROUVREAU Vincent <vincent.rouvreau@inria.fr>2020-11-18 08:03:56 +0100
committerROUVREAU Vincent <vincent.rouvreau@inria.fr>2020-11-18 08:03:56 +0100
commit8b7a25482dfd9c38825e022d5f95135f0aade738 (patch)
treee986157f9921aa261a58c8d812f2802cab248310 /src/python
parentd33eaa80b7c337fde11bb5db60df79fbc81fb483 (diff)
parentad5d38986542715e0a0518537afaadcda71d9c49 (diff)
merge master and resolve conflicts
Diffstat (limited to 'src/python')
-rw-r--r--src/python/CMakeLists.txt101
-rw-r--r--src/python/doc/bottleneck_distance_user.rst4
-rw-r--r--src/python/doc/installation.rst9
-rw-r--r--src/python/doc/rips_complex_sum.inc22
-rw-r--r--src/python/doc/rips_complex_user.rst6
-rwxr-xr-xsrc/python/example/alpha_complex_diagram_persistence_from_off_file_example.py2
-rwxr-xr-xsrc/python/example/euclidean_strong_witness_complex_diagram_persistence_from_off_file_example.py2
-rwxr-xr-xsrc/python/example/euclidean_witness_complex_diagram_persistence_from_off_file_example.py2
-rwxr-xr-xsrc/python/example/periodic_cubical_complex_barcode_persistence_from_perseus_file_example.py2
-rwxr-xr-xsrc/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py2
-rwxr-xr-xsrc/python/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py2
-rwxr-xr-xsrc/python/example/rips_complex_diagram_persistence_from_off_file_example.py2
-rwxr-xr-xsrc/python/example/tangential_complex_plain_homology_from_off_file_example.py2
-rw-r--r--src/python/gudhi/representations/vector_methods.py17
-rw-r--r--src/python/gudhi/simplex_tree.pxd8
-rw-r--r--src/python/gudhi/simplex_tree.pyx38
-rw-r--r--src/python/gudhi/subsampling.pyx21
-rw-r--r--src/python/include/Simplex_tree_interface.h10
-rwxr-xr-xsrc/python/test/test_bottleneck_distance.py12
-rwxr-xr-xsrc/python/test/test_representations.py20
-rwxr-xr-xsrc/python/test/test_simplex_tree.py45
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py24
-rwxr-xr-xsrc/python/test/test_wasserstein_with_tensors.py47
23 files changed, 269 insertions, 131 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 4f26481e..56b6876c 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -103,6 +103,9 @@ if(PYTHONINTERP_FOUND)
if(EAGERPY_FOUND)
add_gudhi_debug_info("EagerPy version ${EAGERPY_VERSION}")
endif()
+ if(TENSORFLOW_FOUND)
+ add_gudhi_debug_info("TensorFlow version ${TENSORFLOW_VERSION}")
+ endif()
set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DBOOST_RESULT_OF_USE_DECLTYPE', ")
set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DBOOST_ALL_NO_LIB', ")
@@ -342,29 +345,27 @@ if(PYTHONINTERP_FOUND)
COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/alpha_rips_persistence_bottleneck_distance.py"
-f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off -t 0.15 -d 3)
- if(MATPLOTLIB_FOUND AND NUMPY_FOUND)
- # Tangential
- add_test(NAME tangential_complex_plain_homology_from_off_file_example_py_test
- WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
- ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/tangential_complex_plain_homology_from_off_file_example.py"
- --no-diagram -i 2 -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off)
-
- add_gudhi_py_test(test_tangential_complex)
-
- # Witness complex AND Subsampling
- add_test(NAME euclidean_strong_witness_complex_diagram_persistence_from_off_file_example_py_test
- WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
- ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/euclidean_strong_witness_complex_diagram_persistence_from_off_file_example.py"
- --no-diagram -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off -a 1.0 -n 20 -d 2)
-
- add_test(NAME euclidean_witness_complex_diagram_persistence_from_off_file_example_py_test
- WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
- ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/euclidean_witness_complex_diagram_persistence_from_off_file_example.py"
- --no-diagram -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off -a 1.0 -n 20 -d 2)
- endif()
+ # Tangential
+ add_test(NAME tangential_complex_plain_homology_from_off_file_example_py_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/tangential_complex_plain_homology_from_off_file_example.py"
+ --no-diagram -i 2 -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off)
+
+ add_gudhi_py_test(test_tangential_complex)
+
+ # Witness complex
+ add_test(NAME euclidean_strong_witness_complex_diagram_persistence_from_off_file_example_py_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/euclidean_strong_witness_complex_diagram_persistence_from_off_file_example.py"
+ --no-diagram -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off -a 1.0 -n 20 -d 2)
+
+ add_test(NAME euclidean_witness_complex_diagram_persistence_from_off_file_example_py_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/euclidean_witness_complex_diagram_persistence_from_off_file_example.py"
+ --no-diagram -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off -a 1.0 -n 20 -d 2)
# Subsampling
add_gudhi_py_test(test_subsampling)
@@ -419,13 +420,11 @@ if(PYTHONINTERP_FOUND)
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/alpha_complex_from_points_example.py")
- if(MATPLOTLIB_FOUND AND NUMPY_FOUND)
- add_test(NAME alpha_complex_diagram_persistence_from_off_file_example_py_test
- WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
- ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/alpha_complex_diagram_persistence_from_off_file_example.py"
- --no-diagram -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off -a 0.6)
- endif()
+ add_test(NAME alpha_complex_diagram_persistence_from_off_file_example_py_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/alpha_complex_diagram_persistence_from_off_file_example.py"
+ --no-diagram -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off -a 0.6)
add_gudhi_py_test(test_alpha_complex)
endif (NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
@@ -442,30 +441,26 @@ if(PYTHONINTERP_FOUND)
${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/periodic_cubical_complex_barcode_persistence_from_perseus_file_example.py"
--no-barcode -f ${CMAKE_SOURCE_DIR}/data/bitmap/CubicalTwoSphere.txt)
- if(NUMPY_FOUND)
- add_test(NAME random_cubical_complex_persistence_example_py_test
- WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
- ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/random_cubical_complex_persistence_example.py"
- 10 10 10)
- endif()
+ add_test(NAME random_cubical_complex_persistence_example_py_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/random_cubical_complex_persistence_example.py"
+ 10 10 10)
add_gudhi_py_test(test_cubical_complex)
# Rips
- if(MATPLOTLIB_FOUND AND NUMPY_FOUND)
- add_test(NAME rips_complex_diagram_persistence_from_distance_matrix_file_example_py_test
- WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
- ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py"
- --no-diagram -f ${CMAKE_SOURCE_DIR}/data/distance_matrix/lower_triangular_distance_matrix.csv -e 12.0 -d 3)
+ add_test(NAME rips_complex_diagram_persistence_from_distance_matrix_file_example_py_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py"
+ --no-diagram -f ${CMAKE_SOURCE_DIR}/data/distance_matrix/lower_triangular_distance_matrix.csv -e 12.0 -d 3)
- add_test(NAME rips_complex_diagram_persistence_from_off_file_example_py_test
- WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
- ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/example/rips_complex_diagram_persistence_from_off_file_example.py
- --no-diagram -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off -e 0.25 -d 3)
- endif()
+ add_test(NAME rips_complex_diagram_persistence_from_off_file_example_py_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/example/rips_complex_diagram_persistence_from_off_file_example.py
+ --no-diagram -f ${CMAKE_SOURCE_DIR}/data/points/tore3D_300.off -e 0.25 -d 3)
add_test(NAME rips_complex_from_points_example_py_test
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
@@ -496,11 +491,17 @@ if(PYTHONINTERP_FOUND)
# Wasserstein
if(OT_FOUND AND PYBIND11_FOUND)
- if(TORCH_FOUND AND EAGERPY_FOUND)
+ # EagerPy dependency because of enable_autodiff=True
+ if(EAGERPY_FOUND)
add_gudhi_py_test(test_wasserstein_distance)
endif()
add_gudhi_py_test(test_wasserstein_barycenter)
endif()
+ if(OT_FOUND)
+ if(TORCH_FOUND AND TENSORFLOW_FOUND AND EAGERPY_FOUND)
+ add_gudhi_py_test(test_wasserstein_with_tensors)
+ endif()
+ endif()
# Representations
if(SKLEARN_FOUND AND MATPLOTLIB_FOUND)
diff --git a/src/python/doc/bottleneck_distance_user.rst b/src/python/doc/bottleneck_distance_user.rst
index 6c6e08d9..7baa76cc 100644
--- a/src/python/doc/bottleneck_distance_user.rst
+++ b/src/python/doc/bottleneck_distance_user.rst
@@ -47,7 +47,7 @@ The following example explains how the distance is computed:
:figclass: align-center
The point (0, 13) is at distance 6.5 from the diagonal and more
- specifically from the point (6.5, 6.5)
+ specifically from the point (6.5, 6.5).
Basic example
@@ -72,6 +72,6 @@ The output is:
.. testoutput::
- Bottleneck distance approximation = 0.81
+ Bottleneck distance approximation = 0.72
Bottleneck distance value = 0.75
diff --git a/src/python/doc/installation.rst b/src/python/doc/installation.rst
index 78e1af73..66efe45a 100644
--- a/src/python/doc/installation.rst
+++ b/src/python/doc/installation.rst
@@ -40,7 +40,7 @@ different, and in particular the `python/` subdirectory is actually `src/python/
there.
The library uses c++14 and requires `Boost <https://www.boost.org/>`_ :math:`\geq` 1.56.0,
-`CMake <https://www.cmake.org/>`_ :math:`\geq` 3.1 to generate makefiles,
+`CMake <https://www.cmake.org/>`_ :math:`\geq` 3.5 to generate makefiles,
`NumPy <http://numpy.org>`_, `Cython <https://www.cython.org/>`_ and
`pybind11 <https://github.com/pybind/pybind11>`_ to compile
the GUDHI Python module.
@@ -65,7 +65,7 @@ one can build the GUDHI Python module, by running the following commands in a te
cd /path-to-gudhi/
mkdir build
cd build/
- cmake ..
+ cmake -DCMAKE_BUILD_TYPE=Release ..
cd python
make
@@ -394,6 +394,11 @@ mathematics, science, and engineering.
:class:`~gudhi.point_cloud.knn.KNearestNeighbors` can use the Python package
`SciPy <http://scipy.org>`_ as a backend if explicitly requested.
+TensorFlow
+----------
+
+`TensorFlow <https://www.tensorflow.org>`_ is currently only used in some automatic differentiation tests.
+
Bug reports and contributions
*****************************
diff --git a/src/python/doc/rips_complex_sum.inc b/src/python/doc/rips_complex_sum.inc
index c123ea2a..2cb24990 100644
--- a/src/python/doc/rips_complex_sum.inc
+++ b/src/python/doc/rips_complex_sum.inc
@@ -1,14 +1,14 @@
.. table::
:widths: 30 40 30
- +----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------+
- | .. figure:: | The Vietoris-Rips complex is a simplicial complex built as the | :Authors: Clément Maria, Pawel Dlotko, Vincent Rouvreau, Marc Glisse |
- | ../../doc/Rips_complex/rips_complex_representation.png | clique-complex of a proximity graph. | |
- | :figclass: align-center | | :Since: GUDHI 2.0.0 |
- | | We also provide sparse approximations, to speed-up the computation | |
- | | of persistent homology, and weighted versions, which are more robust | :License: MIT |
- | | to outliers. | |
- | | | |
- +----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------+
- | * :doc:`rips_complex_user` | * :doc:`rips_complex_ref` |
- +----------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------+
+ +----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+
+ | .. figure:: | The Vietoris-Rips complex is a simplicial complex built as the | :Authors: Clément Maria, Pawel Dlotko, Vincent Rouvreau, Marc Glisse, Yuichi Ike |
+ | ../../doc/Rips_complex/rips_complex_representation.png | clique-complex of a proximity graph. | |
+ | :figclass: align-center | | :Since: GUDHI 2.0.0 |
+ | | We also provide sparse approximations, to speed-up the computation | |
+ | | of persistent homology, and weighted versions, which are more robust | :License: MIT |
+ | | to outliers. | |
+ | | | |
+ +----------------------------------------------------------------+------------------------------------------------------------------------+----------------------------------------------------------------------------------+
+ | * :doc:`rips_complex_user` | * :doc:`rips_complex_ref` |
+ +----------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------+
diff --git a/src/python/doc/rips_complex_user.rst b/src/python/doc/rips_complex_user.rst
index 6048cc4e..27d218d4 100644
--- a/src/python/doc/rips_complex_user.rst
+++ b/src/python/doc/rips_complex_user.rst
@@ -7,9 +7,9 @@ Rips complex user manual
Definition
----------
-==================================================================== ================================ ======================
-:Authors: Clément Maria, Pawel Dlotko, Vincent Rouvreau, Marc Glisse :Since: GUDHI 2.0.0 :License: GPL v3
-==================================================================== ================================ ======================
+================================================================================ ================================ ======================
+:Authors: Clément Maria, Pawel Dlotko, Vincent Rouvreau, Marc Glisse, Yuichi Ike :Since: GUDHI 2.0.0 :License: GPL v3
+================================================================================ ================================ ======================
+-------------------------------------------+----------------------------------------------------------------------+
| :doc:`rips_complex_user` | :doc:`rips_complex_ref` |
diff --git a/src/python/example/alpha_complex_diagram_persistence_from_off_file_example.py b/src/python/example/alpha_complex_diagram_persistence_from_off_file_example.py
index 727af4fa..1e0273b3 100755
--- a/src/python/example/alpha_complex_diagram_persistence_from_off_file_example.py
+++ b/src/python/example/alpha_complex_diagram_persistence_from_off_file_example.py
@@ -3,7 +3,6 @@
import argparse
import errno
import os
-import matplotlib.pyplot as plot
import gudhi
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ -
@@ -65,6 +64,7 @@ with open(args.file, "r") as f:
print(simplex_tree.betti_numbers())
if args.no_diagram == False:
+ import matplotlib.pyplot as plot
gudhi.plot_persistence_diagram(diag, band=args.band)
plot.show()
else:
diff --git a/src/python/example/euclidean_strong_witness_complex_diagram_persistence_from_off_file_example.py b/src/python/example/euclidean_strong_witness_complex_diagram_persistence_from_off_file_example.py
index e1e572df..4e97cfe3 100755
--- a/src/python/example/euclidean_strong_witness_complex_diagram_persistence_from_off_file_example.py
+++ b/src/python/example/euclidean_strong_witness_complex_diagram_persistence_from_off_file_example.py
@@ -3,7 +3,6 @@
import argparse
import errno
import os
-import matplotlib.pyplot as plot
import gudhi
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ -
@@ -82,6 +81,7 @@ with open(args.file, "r") as f:
print(simplex_tree.betti_numbers())
if args.no_diagram == False:
+ import matplotlib.pyplot as plot
gudhi.plot_persistence_diagram(diag, band=args.band)
plot.show()
else:
diff --git a/src/python/example/euclidean_witness_complex_diagram_persistence_from_off_file_example.py b/src/python/example/euclidean_witness_complex_diagram_persistence_from_off_file_example.py
index 58cb2bb5..29076c74 100755
--- a/src/python/example/euclidean_witness_complex_diagram_persistence_from_off_file_example.py
+++ b/src/python/example/euclidean_witness_complex_diagram_persistence_from_off_file_example.py
@@ -3,7 +3,6 @@
import argparse
import errno
import os
-import matplotlib.pyplot as plot
import gudhi
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ -
@@ -79,6 +78,7 @@ with open(args.file, "r") as f:
print(simplex_tree.betti_numbers())
if args.no_diagram == False:
+ import matplotlib.pyplot as plot
gudhi.plot_persistence_diagram(diag, band=args.band)
plot.show()
else:
diff --git a/src/python/example/periodic_cubical_complex_barcode_persistence_from_perseus_file_example.py b/src/python/example/periodic_cubical_complex_barcode_persistence_from_perseus_file_example.py
index 499171df..ee3290c6 100755
--- a/src/python/example/periodic_cubical_complex_barcode_persistence_from_perseus_file_example.py
+++ b/src/python/example/periodic_cubical_complex_barcode_persistence_from_perseus_file_example.py
@@ -1,7 +1,6 @@
#!/usr/bin/env python
import argparse
-import matplotlib.pyplot as plot
import errno
import os
import gudhi
@@ -75,6 +74,7 @@ if is_file_perseus(args.file):
print("betti_numbers()=")
print(periodic_cubical_complex.betti_numbers())
if args.no_barcode == False:
+ import matplotlib.pyplot as plot
gudhi.plot_persistence_barcode(diag)
plot.show()
else:
diff --git a/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py b/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py
index 1acb187c..ea2eb7e1 100755
--- a/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py
+++ b/src/python/example/rips_complex_diagram_persistence_from_correlation_matrix_file_example.py
@@ -2,7 +2,6 @@
import sys
import argparse
-import matplotlib.pyplot as plot
import gudhi
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
@@ -84,5 +83,6 @@ invert_diag = [
]
if args.no_diagram == False:
+ import matplotlib.pyplot as plot
gudhi.plot_persistence_diagram(invert_diag, band=args.band)
plot.show()
diff --git a/src/python/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py b/src/python/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py
index 79ccca96..236d085d 100755
--- a/src/python/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py
+++ b/src/python/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py
@@ -1,7 +1,6 @@
#!/usr/bin/env python
import argparse
-import matplotlib.pyplot as plot
import gudhi
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
@@ -60,5 +59,6 @@ print("betti_numbers()=")
print(simplex_tree.betti_numbers())
if args.no_diagram == False:
+ import matplotlib.pyplot as plot
gudhi.plot_persistence_diagram(diag, band=args.band)
plot.show()
diff --git a/src/python/example/rips_complex_diagram_persistence_from_off_file_example.py b/src/python/example/rips_complex_diagram_persistence_from_off_file_example.py
index 6f992508..e80233a9 100755
--- a/src/python/example/rips_complex_diagram_persistence_from_off_file_example.py
+++ b/src/python/example/rips_complex_diagram_persistence_from_off_file_example.py
@@ -3,7 +3,6 @@
import argparse
import errno
import os
-import matplotlib.pyplot as plot
import gudhi
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ -
@@ -70,6 +69,7 @@ with open(args.file, "r") as f:
print(simplex_tree.betti_numbers())
if args.no_diagram == False:
+ import matplotlib.pyplot as plot
gudhi.plot_persistence_diagram(diag, band=args.band)
plot.show()
else:
diff --git a/src/python/example/tangential_complex_plain_homology_from_off_file_example.py b/src/python/example/tangential_complex_plain_homology_from_off_file_example.py
index 85bade4a..a4b4e9f5 100755
--- a/src/python/example/tangential_complex_plain_homology_from_off_file_example.py
+++ b/src/python/example/tangential_complex_plain_homology_from_off_file_example.py
@@ -3,7 +3,6 @@
import argparse
import errno
import os
-import matplotlib.pyplot as plot
import gudhi
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ -
@@ -62,6 +61,7 @@ with open(args.file, "r") as f:
print(st.betti_numbers())
if args.no_diagram == False:
+ import matplotlib.pyplot as plot
gudhi.plot_persistence_diagram(diag, band=args.band)
plot.show()
else:
diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py
index 5ca127f6..cdcb1fde 100644
--- a/src/python/gudhi/representations/vector_methods.py
+++ b/src/python/gudhi/representations/vector_methods.py
@@ -323,22 +323,15 @@ class BettiCurve(BaseEstimator, TransformerMixin):
Returns:
numpy array with shape (number of diagrams) x (**resolution**): output Betti curves.
"""
- num_diag, Xfit = len(X), []
+ Xfit = []
x_values = np.linspace(self.sample_range[0], self.sample_range[1], self.resolution)
step_x = x_values[1] - x_values[0]
- for i in range(num_diag):
-
- diagram, num_pts_in_diag = X[i], X[i].shape[0]
-
+ for diagram in X:
+ diagram_int = np.clip(np.ceil((diagram[:,:2] - self.sample_range[0]) / step_x), 0, self.resolution).astype(int)
bc = np.zeros(self.resolution)
- for j in range(num_pts_in_diag):
- [px,py] = diagram[j,:2]
- min_idx = np.clip(np.ceil((px - self.sample_range[0]) / step_x).astype(int), 0, self.resolution)
- max_idx = np.clip(np.ceil((py - self.sample_range[0]) / step_x).astype(int), 0, self.resolution)
- for k in range(min_idx, max_idx):
- bc[k] += 1
-
+ for interval in diagram_int:
+ bc[interval[0]:interval[1]] += 1
Xfit.append(np.reshape(bc,[1,-1]))
Xfit = np.concatenate(Xfit, 0)
diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd
index 75e94e0b..3c4cbed3 100644
--- a/src/python/gudhi/simplex_tree.pxd
+++ b/src/python/gudhi/simplex_tree.pxd
@@ -36,6 +36,12 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi":
Simplex_tree_skeleton_iterator operator++() nogil
bint operator!=(Simplex_tree_skeleton_iterator) nogil
+ cdef cppclass Simplex_tree_boundary_iterator "Gudhi::Simplex_tree_interface<Gudhi::Simplex_tree_options_full_featured>::Boundary_simplex_iterator":
+ Simplex_tree_boundary_iterator() nogil
+ Simplex_tree_simplex_handle& operator*() nogil
+ Simplex_tree_boundary_iterator operator++() nogil
+ bint operator!=(Simplex_tree_boundary_iterator) nogil
+
cdef cppclass Simplex_tree_interface_full_featured "Gudhi::Simplex_tree_interface<Gudhi::Simplex_tree_options_full_featured>":
Simplex_tree() nogil
@@ -58,6 +64,7 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi":
void compute_extended_filtration() nogil
vector[vector[pair[int, pair[double, double]]]] compute_extended_persistence_subdiagrams(vector[pair[int, pair[double, double]]] dgm, double min_persistence) nogil
Simplex_tree_interface_full_featured* collapse_edges(int nb_collapse_iteration) nogil
+ void reset_filtration(double filtration, int dimension) nogil
# Iterators over Simplex tree
pair[vector[int], double] get_simplex_and_filtration(Simplex_tree_simplex_handle f_simplex) nogil
Simplex_tree_simplices_iterator get_simplices_iterator_begin() nogil
@@ -66,6 +73,7 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi":
vector[Simplex_tree_simplex_handle].const_iterator get_filtration_iterator_end() nogil
Simplex_tree_skeleton_iterator get_skeleton_iterator_begin(int dimension) nogil
Simplex_tree_skeleton_iterator get_skeleton_iterator_end(int dimension) nogil
+ pair[Simplex_tree_boundary_iterator, Simplex_tree_boundary_iterator] get_boundary_iterators(vector[int] simplex) nogil except +
cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi":
cdef cppclass Simplex_tree_persistence_interface "Gudhi::Persistent_cohomology_interface<Gudhi::Simplex_tree<Gudhi::Simplex_tree_options_full_featured>>":
diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx
index 92645ffc..cdd2e87b 100644
--- a/src/python/gudhi/simplex_tree.pyx
+++ b/src/python/gudhi/simplex_tree.pyx
@@ -285,6 +285,22 @@ cdef class SimplexTree:
ct.append((v, filtered_simplex.second))
return ct
+ def get_boundaries(self, simplex):
+ """This function returns a generator with the boundaries of a given N-simplex.
+ If you do not need the filtration values, the boundary can also be obtained as
+ :code:`itertools.combinations(simplex,len(simplex)-1)`.
+
+ :param simplex: The N-simplex, represented by a list of vertex.
+ :type simplex: list of int.
+ :returns: The (simplices of the) boundary of a simplex
+ :rtype: generator with tuples(simplex, filtration)
+ """
+ cdef pair[Simplex_tree_boundary_iterator, Simplex_tree_boundary_iterator] it = self.get_ptr().get_boundary_iterators(simplex)
+
+ while it.first != it.second:
+ yield self.get_ptr().get_simplex_and_filtration(dereference(it.first))
+ preincrement(it.first)
+
def remove_maximal_simplex(self, simplex):
"""This function removes a given maximal N-simplex from the simplicial
complex.
@@ -328,7 +344,7 @@ cdef class SimplexTree:
return self.get_ptr().prune_above_filtration(filtration)
def expansion(self, max_dim):
- """Expands the Simplex_tree containing only its one skeleton
+ """Expands the simplex tree containing only its one skeleton
until dimension max_dim.
The expanded simplicial complex until dimension :math:`d`
@@ -338,7 +354,7 @@ cdef class SimplexTree:
The filtration value assigned to a simplex is the maximal filtration
value of one of its edges.
- The Simplex_tree must contain no simplex of dimension bigger than
+ The simplex tree must contain no simplex of dimension bigger than
1 when calling the method.
:param max_dim: The maximal dimension.
@@ -358,6 +374,20 @@ cdef class SimplexTree:
"""
return self.get_ptr().make_filtration_non_decreasing()
+ def reset_filtration(self, filtration, min_dim = 0):
+ """This function resets the filtration value of all the simplices of dimension at least min_dim. Resets all the
+ simplex tree when `min_dim = 0`.
+ `reset_filtration` may break the filtration property with `min_dim > 0`, and it is the user's responsibility to
+ make it a valid filtration (using a large enough `filt_value`, or calling `make_filtration_non_decreasing`
+ afterwards for instance).
+
+ :param filtration: New threshold value.
+ :type filtration: float.
+ :param min_dim: The minimal dimension. Default value is 0.
+ :type min_dim: int.
+ """
+ self.get_ptr().reset_filtration(filtration, min_dim)
+
def extend_filtration(self):
""" Extend filtration for computing extended persistence. This function only uses the filtration values at the
0-dimensional simplices, and computes the extended persistence diagram induced by the lower-star filtration
@@ -365,12 +395,12 @@ cdef class SimplexTree:
.. note::
- Note that after calling this function, the filtration values are actually modified within the Simplex_tree.
+ Note that after calling this function, the filtration values are actually modified within the simplex tree.
The function :func:`extended_persistence` retrieves the original values.
.. note::
- Note that this code creates an extra vertex internally, so you should make sure that the Simplex_tree does
+ Note that this code creates an extra vertex internally, so you should make sure that the simplex tree does
not contain a vertex with the largest possible value (i.e., 4294967295).
This `notebook <https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-extended-persistence.ipynb>`_
diff --git a/src/python/gudhi/subsampling.pyx b/src/python/gudhi/subsampling.pyx
index f77c6f75..b11d07e5 100644
--- a/src/python/gudhi/subsampling.pyx
+++ b/src/python/gudhi/subsampling.pyx
@@ -33,7 +33,7 @@ def choose_n_farthest_points(points=None, off_file='', nb_points=0, starting_poi
The iteration starts with the landmark `starting point`.
:param points: The input point set.
- :type points: Iterable[Iterable[float]].
+ :type points: Iterable[Iterable[float]]
Or
@@ -42,14 +42,15 @@ def choose_n_farthest_points(points=None, off_file='', nb_points=0, starting_poi
And in both cases
- :param nb_points: Number of points of the subsample.
- :type nb_points: unsigned.
+ :param nb_points: Number of points of the subsample (the subsample may be \
+ smaller if there are fewer than nb_points distinct input points)
+ :type nb_points: int
:param starting_point: The iteration starts with the landmark `starting \
- point`,which is the index of the point to start with. If not set, this \
+ point`, which is the index of the point to start with. If not set, this \
index is chosen randomly.
- :type starting_point: unsigned.
+ :type starting_point: int
:returns: The subsample point set.
- :rtype: List[List[float]].
+ :rtype: List[List[float]]
"""
if off_file:
if os.path.isfile(off_file):
@@ -76,7 +77,7 @@ def pick_n_random_points(points=None, off_file='', nb_points=0):
"""Subsample a point set by picking random vertices.
:param points: The input point set.
- :type points: Iterable[Iterable[float]].
+ :type points: Iterable[Iterable[float]]
Or
@@ -86,7 +87,7 @@ def pick_n_random_points(points=None, off_file='', nb_points=0):
And in both cases
:param nb_points: Number of points of the subsample.
- :type nb_points: unsigned.
+ :type nb_points: int
:returns: The subsample point set.
:rtype: List[List[float]]
"""
@@ -107,7 +108,7 @@ def sparsify_point_set(points=None, off_file='', min_squared_dist=0.0):
between any two points is greater than or equal to min_squared_dist.
:param points: The input point set.
- :type points: Iterable[Iterable[float]].
+ :type points: Iterable[Iterable[float]]
Or
@@ -118,7 +119,7 @@ def sparsify_point_set(points=None, off_file='', min_squared_dist=0.0):
:param min_squared_dist: Minimum squared distance separating the output \
points.
- :type min_squared_dist: float.
+ :type min_squared_dist: float
:returns: The subsample point set.
:rtype: List[List[float]]
"""
diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h
index e288a8cf..baff3850 100644
--- a/src/python/include/Simplex_tree_interface.h
+++ b/src/python/include/Simplex_tree_interface.h
@@ -39,6 +39,7 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> {
using Skeleton_simplex_iterator = typename Base::Skeleton_simplex_iterator;
using Complex_simplex_iterator = typename Base::Complex_simplex_iterator;
using Extended_filtration_data = typename Base::Extended_filtration_data;
+ using Boundary_simplex_iterator = typename Base::Boundary_simplex_iterator;
public:
@@ -219,6 +220,15 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> {
// this specific case works because the range is just a pair of iterators - won't work if range was a vector
return Base::skeleton_simplex_range(dimension).end();
}
+
+ std::pair<Boundary_simplex_iterator, Boundary_simplex_iterator> get_boundary_iterators(const Simplex& simplex) {
+ auto bd_sh = Base::find(simplex);
+ if (bd_sh == Base::null_simplex())
+ throw std::runtime_error("simplex not found - cannot find boundaries");
+ // this specific case works because the range is just a pair of iterators - won't work if range was a vector
+ auto boundary_srange = Base::boundary_simplex_range(bd_sh);
+ return std::make_pair(boundary_srange.begin(), boundary_srange.end());
+ }
};
} // namespace Gudhi
diff --git a/src/python/test/test_bottleneck_distance.py b/src/python/test/test_bottleneck_distance.py
index 6915bea8..07fcc9cc 100755
--- a/src/python/test/test_bottleneck_distance.py
+++ b/src/python/test/test_bottleneck_distance.py
@@ -25,3 +25,15 @@ def test_basic_bottleneck():
assert gudhi.bottleneck_distance(diag1, diag2, 0.1) == pytest.approx(0.75, abs=0.1)
assert gudhi.hera.bottleneck_distance(diag1, diag2, 0) == 0.75
assert gudhi.hera.bottleneck_distance(diag1, diag2, 0.1) == pytest.approx(0.75, rel=0.1)
+
+ import numpy as np
+
+ # Translating both diagrams along the diagonal should not affect the distance,
+ # checks that negative numbers are not an issue
+ diag1 = np.array(diag1) - 100
+ diag2 = np.array(diag2) - 100
+
+ assert gudhi.bottleneck_distance(diag1, diag2) == 0.75
+ assert gudhi.bottleneck_distance(diag1, diag2, 0.1) == pytest.approx(0.75, abs=0.1)
+ assert gudhi.hera.bottleneck_distance(diag1, diag2, 0) == 0.75
+ assert gudhi.hera.bottleneck_distance(diag1, diag2, 0.1) == pytest.approx(0.75, rel=0.1)
diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py
index e5c211a0..43c914f3 100755
--- a/src/python/test/test_representations.py
+++ b/src/python/test/test_representations.py
@@ -39,11 +39,11 @@ def test_multiple():
d2 = BottleneckDistance(epsilon=0.00001).fit_transform(l1)
d3 = pairwise_persistence_diagram_distances(l1, l1b, e=0.00001, n_jobs=4)
assert d1 == pytest.approx(d2)
- assert d3 == pytest.approx(d2, abs=1e-5) # Because of 0 entries (on the diagonal)
+ assert d3 == pytest.approx(d2, abs=1e-5) # Because of 0 entries (on the diagonal)
d1 = pairwise_persistence_diagram_distances(l1, l2, metric="wasserstein", order=2, internal_p=2)
d2 = WassersteinDistance(order=2, internal_p=2, n_jobs=4).fit(l2).transform(l1)
print(d1.shape, d2.shape)
- assert d1 == pytest.approx(d2, rel=.02)
+ assert d1 == pytest.approx(d2, rel=0.02)
def test_dummy_atol():
@@ -53,8 +53,22 @@ def test_dummy_atol():
for weighting_method in ["cloud", "iidproba"]:
for contrast in ["gaussian", "laplacian", "indicator"]:
- atol_vectoriser = Atol(quantiser=KMeans(n_clusters=1, random_state=202006), weighting_method=weighting_method, contrast=contrast)
+ atol_vectoriser = Atol(
+ quantiser=KMeans(n_clusters=1, random_state=202006),
+ weighting_method=weighting_method,
+ contrast=contrast,
+ )
atol_vectoriser.fit([a, b, c])
atol_vectoriser(a)
atol_vectoriser.transform(X=[a, b, c])
+
+from gudhi.representations.vector_methods import BettiCurve
+
+
+def test_infinity():
+ a = np.array([[1.0, 8.0], [2.0, np.inf], [3.0, 4.0]])
+ c = BettiCurve(20, [0.0, 10.0])(a)
+ assert c[1] == 0
+ assert c[7] == 3
+ assert c[9] == 2
diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py
index 83be0602..3b23fa0b 100755
--- a/src/python/test/test_simplex_tree.py
+++ b/src/python/test/test_simplex_tree.py
@@ -356,5 +356,46 @@ def test_collapse_edges():
st.collapse_edges()
assert st.num_simplices() == 9
assert st.find([1, 3]) == False
- for simplex in st.get_skeleton(0):
- assert simplex[1] == 1.
+ for simplex in st.get_skeleton(0):
+ assert simplex[1] == 1.
+
+def test_reset_filtration():
+ st = SimplexTree()
+
+ assert st.insert([0, 1, 2], 3.) == True
+ assert st.insert([0, 3], 2.) == True
+ assert st.insert([3, 4, 5], 3.) == True
+ assert st.insert([0, 1, 6, 7], 4.) == True
+
+ # Guaranteed by construction
+ for simplex in st.get_simplices():
+ assert st.filtration(simplex[0]) >= 2.
+
+ # dimension until 5 even if simplex tree is of dimension 3 to test the limits
+ for dimension in range(5, -1, -1):
+ st.reset_filtration(0., dimension)
+ for simplex in st.get_skeleton(3):
+ print(simplex)
+ if len(simplex[0]) < (dimension) + 1:
+ assert st.filtration(simplex[0]) >= 2.
+ else:
+ assert st.filtration(simplex[0]) == 0.
+
+def test_boundaries_iterator():
+ st = SimplexTree()
+
+ assert st.insert([0, 1, 2, 3], filtration=1.0) == True
+ assert st.insert([1, 2, 3, 4], filtration=2.0) == True
+
+ assert list(st.get_boundaries([1, 2, 3])) == [([1, 2], 1.0), ([1, 3], 1.0), ([2, 3], 1.0)]
+ assert list(st.get_boundaries([2, 3, 4])) == [([2, 3], 1.0), ([2, 4], 2.0), ([3, 4], 2.0)]
+ assert list(st.get_boundaries([2])) == []
+
+ with pytest.raises(RuntimeError):
+ list(st.get_boundaries([]))
+
+ with pytest.raises(RuntimeError):
+ list(st.get_boundaries([0, 4])) # (0, 4) does not exist
+
+ with pytest.raises(RuntimeError):
+ list(st.get_boundaries([6])) # (6) does not exist
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 90d26809..e3b521d6 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -97,27 +97,3 @@ def test_wasserstein_distance_pot():
def test_wasserstein_distance_hera():
_basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False)
_basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False)
-
-def test_wasserstein_distance_grad():
- import torch
-
- diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True)
- diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
- diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
- assert diag1.grad is None and diag2.grad is None and diag3.grad is None
- dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True)
- dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True)
- dist12.backward()
- dist30.backward()
- assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any()
- diag4 = torch.tensor([[0., 10.]], requires_grad=True)
- diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True)
- dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True)
- assert dist45 == 3.
- dist45.backward()
- assert np.array_equal(diag4.grad, [[-1., -1.]])
- assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]])
- diag6 = torch.tensor([[5., 10.]], requires_grad=True)
- pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward()
- # https://github.com/jonasrauber/eagerpy/issues/6
- # assert np.array_equal(diag6.grad, [[0., 0.]])
diff --git a/src/python/test/test_wasserstein_with_tensors.py b/src/python/test/test_wasserstein_with_tensors.py
new file mode 100755
index 00000000..e3f1411a
--- /dev/null
+++ b/src/python/test/test_wasserstein_with_tensors.py
@@ -0,0 +1,47 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Mathieu Carriere
+
+ Copyright (C) 2020 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.wasserstein import wasserstein_distance as pot
+import numpy as np
+import torch
+import tensorflow as tf
+
+def test_wasserstein_distance_grad():
+ diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True)
+ diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
+ diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
+ assert diag1.grad is None and diag2.grad is None and diag3.grad is None
+ dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True)
+ dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True)
+ dist12.backward()
+ dist30.backward()
+ assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any()
+ diag4 = torch.tensor([[0., 10.]], requires_grad=True)
+ diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True)
+ dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True)
+ assert dist45 == 3.
+ dist45.backward()
+ assert np.array_equal(diag4.grad, [[-1., -1.]])
+ assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]])
+ diag6 = torch.tensor([[5., 10.]], requires_grad=True)
+ pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward()
+ # https://github.com/jonasrauber/eagerpy/issues/6
+ # assert np.array_equal(diag6.grad, [[0., 0.]])
+
+def test_wasserstein_distance_grad_tensorflow():
+ with tf.GradientTape() as tape:
+ diag4 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[0., 10.]]), trainable=True))
+ diag5 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[1., 11.], [3., 4.]]), trainable=True))
+ dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True)
+ assert dist45 == 3.
+
+ grads = tape.gradient(dist45, [diag4, diag5])
+ assert np.array_equal(grads[0].values, [[-1., -1.]])
+ assert np.array_equal(grads[1].values, [[1., 1.], [-1., 1.]]) \ No newline at end of file