From 88a36ffad6c11279990c1c96df32b95c1f6f526c Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Fri, 3 Jul 2020 13:57:49 +0200 Subject: A fix proposal for boudaries of a simplex python version --- .../include/gudhi/Simplex_tree/Simplex_tree_iterators.h | 5 +++++ src/python/gudhi/simplex_tree.pxd | 8 ++++++++ src/python/gudhi/simplex_tree.pyx | 16 ++++++++++++++++ src/python/include/Simplex_tree_interface.h | 11 +++++++++++ src/python/test/test_simplex_tree.py | 9 +++++++++ 5 files changed, 49 insertions(+) diff --git a/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h b/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h index 9007b6bd..9c864454 100644 --- a/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h +++ b/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h @@ -85,6 +85,11 @@ class Simplex_tree_boundary_simplex_iterator : public boost::iterator_facade< typedef typename SimplexTree::Vertex_handle Vertex_handle; typedef typename SimplexTree::Siblings Siblings; + Simplex_tree_boundary_simplex_iterator() + : sib_(nullptr), + st_(nullptr) { + } + // any end() iterator explicit Simplex_tree_boundary_simplex_iterator(SimplexTree * st) : last_(st->null_vertex()), diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd index e748ac40..a64599e7 100644 --- a/src/python/gudhi/simplex_tree.pxd +++ b/src/python/gudhi/simplex_tree.pxd @@ -36,6 +36,12 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": Simplex_tree_skeleton_iterator operator++() nogil bint operator!=(Simplex_tree_skeleton_iterator) nogil + cdef cppclass Simplex_tree_boundary_iterator "Gudhi::Simplex_tree_interface::Boundary_simplex_iterator": + Simplex_tree_boundary_iterator() nogil + Simplex_tree_simplex_handle& operator*() nogil + Simplex_tree_boundary_iterator operator++() nogil + bint operator!=(Simplex_tree_boundary_iterator) nogil + cdef cppclass Simplex_tree_interface_full_featured "Gudhi::Simplex_tree_interface": Simplex_tree() nogil @@ -65,6 +71,8 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": vector[Simplex_tree_simplex_handle].const_iterator get_filtration_iterator_end() nogil Simplex_tree_skeleton_iterator get_skeleton_iterator_begin(int dimension) nogil Simplex_tree_skeleton_iterator get_skeleton_iterator_end(int dimension) nogil + Simplex_tree_boundary_iterator get_boundary_iterator_begin(vector[int] simplex) nogil + Simplex_tree_boundary_iterator get_boundary_iterator_end(vector[int] simplex) nogil cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi": cdef cppclass Simplex_tree_persistence_interface "Gudhi::Persistent_cohomology_interface>": diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index 20e66d9f..da445a12 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -285,6 +285,22 @@ cdef class SimplexTree: ct.append((v, filtered_simplex.second)) return ct + def get_boundaries(self, simplex): + """This function returns a generator with the boundaries of a given N-simplex. + + :param simplex: The N-simplex, represented by a list of vertex. + :type simplex: list of int. + :returns: The (simplices of the) boundaries of a simplex + :rtype: generator with tuples(simplex, filtration) + """ + cdef Simplex_tree_boundary_iterator it = self.get_ptr().get_boundary_iterator_begin(simplex) + cdef Simplex_tree_boundary_iterator end = self.get_ptr().get_boundary_iterator_end(simplex) + + while it != end: + yield self.get_ptr().get_simplex_and_filtration(dereference(it)) + preincrement(it) + + def remove_maximal_simplex(self, simplex): """This function removes a given maximal N-simplex from the simplicial complex. diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h index 56d7c41d..c4f18eeb 100644 --- a/src/python/include/Simplex_tree_interface.h +++ b/src/python/include/Simplex_tree_interface.h @@ -36,6 +36,7 @@ class Simplex_tree_interface : public Simplex_tree { using Skeleton_simplex_iterator = typename Base::Skeleton_simplex_iterator; using Complex_simplex_iterator = typename Base::Complex_simplex_iterator; using Extended_filtration_data = typename Base::Extended_filtration_data; + using Boundary_simplex_iterator = typename Base::Boundary_simplex_iterator; public: @@ -188,6 +189,16 @@ class Simplex_tree_interface : public Simplex_tree { // this specific case works because the range is just a pair of iterators - won't work if range was a vector return Base::skeleton_simplex_range(dimension).end(); } + + Boundary_simplex_iterator get_boundary_iterator_begin(const Simplex& simplex) { + // this specific case works because the range is just a pair of iterators - won't work if range was a vector + return Base::boundary_simplex_range(Base::find(simplex)).begin(); + } + + Boundary_simplex_iterator get_boundary_iterator_end(const Simplex& simplex) { + // this specific case works because the range is just a pair of iterators - won't work if range was a vector + return Base::boundary_simplex_range(Base::find(simplex)).end(); + } }; } // namespace Gudhi diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 2137d822..7c49cb1a 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -340,3 +340,12 @@ def test_simplices_iterator(): assert st.find(simplex[0]) == True print("filtration is: ", simplex[1]) assert st.filtration(simplex[0]) == simplex[1] + +def test_boundaries_iterator(): + st = SimplexTree() + + assert st.insert([0, 1, 2, 3], filtration=1.0) == True + assert st.insert([1, 2, 3, 4], filtration=2.0) == True + + assert list(st.get_boundaries([1, 2, 3])) == [([1, 2], 1.0), ([1, 3], 1.0), ([2, 3], 1.0)] + assert list(st.get_boundaries([2, 3, 4])) == [([2, 3], 1.0), ([2, 4], 2.0), ([3, 4], 2.0)] -- cgit v1.2.3 From 77cda62efd84ca7090d8796ca0197b03dbc22350 Mon Sep 17 00:00:00 2001 From: Vincent Rouvreau <10407034+VincentRouvreau@users.noreply.github.com> Date: Wed, 12 Aug 2020 07:28:31 +0200 Subject: Update src/python/gudhi/simplex_tree.pyx Co-authored-by: Marc Glisse --- src/python/gudhi/simplex_tree.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index da445a12..3ebae923 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -290,7 +290,7 @@ cdef class SimplexTree: :param simplex: The N-simplex, represented by a list of vertex. :type simplex: list of int. - :returns: The (simplices of the) boundaries of a simplex + :returns: The (simplices of the) boundary of a simplex :rtype: generator with tuples(simplex, filtration) """ cdef Simplex_tree_boundary_iterator it = self.get_ptr().get_boundary_iterator_begin(simplex) -- cgit v1.2.3 From 1f31e5ce995b0e7683c4eab7077317e28ee5b1b3 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Wed, 12 Aug 2020 08:02:00 +0200 Subject: code review: add a comment about Simplex_tree_boundary_simplex_iterator default ctor --- src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h b/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h index 9c864454..ee64a277 100644 --- a/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h +++ b/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h @@ -85,6 +85,7 @@ class Simplex_tree_boundary_simplex_iterator : public boost::iterator_facade< typedef typename SimplexTree::Vertex_handle Vertex_handle; typedef typename SimplexTree::Siblings Siblings; + // For cython purpose only. The object it initializes should be overwritten ASAP and never used before it is overwritten. Simplex_tree_boundary_simplex_iterator() : sib_(nullptr), st_(nullptr) { -- cgit v1.2.3 From 8f9c065df7f4629555ef09292c14c293891f1bdc Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Wed, 12 Aug 2020 10:03:26 +0200 Subject: code review: Add test to get boundaries from a vertex --- src/python/test/test_simplex_tree.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 7c49cb1a..036e88df 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -349,3 +349,4 @@ def test_boundaries_iterator(): assert list(st.get_boundaries([1, 2, 3])) == [([1, 2], 1.0), ([1, 3], 1.0), ([2, 3], 1.0)] assert list(st.get_boundaries([2, 3, 4])) == [([2, 3], 1.0), ([2, 4], 2.0), ([3, 4], 2.0)] + assert list(st.get_boundaries([2])) == [] -- cgit v1.2.3 From 458bc2dcf5044e1d5fde5326b2be35e526abd457 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Wed, 12 Aug 2020 13:06:03 +0200 Subject: code review: boundaries uses only once find and return a pair of iterator. Exception is raised when not found. tested --- src/python/gudhi/simplex_tree.pxd | 3 +-- src/python/gudhi/simplex_tree.pyx | 14 ++++++++------ src/python/include/Simplex_tree_interface.h | 12 +++++------- src/python/test/test_simplex_tree.py | 9 +++++++++ 4 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd index a64599e7..e1b03391 100644 --- a/src/python/gudhi/simplex_tree.pxd +++ b/src/python/gudhi/simplex_tree.pxd @@ -71,8 +71,7 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": vector[Simplex_tree_simplex_handle].const_iterator get_filtration_iterator_end() nogil Simplex_tree_skeleton_iterator get_skeleton_iterator_begin(int dimension) nogil Simplex_tree_skeleton_iterator get_skeleton_iterator_end(int dimension) nogil - Simplex_tree_boundary_iterator get_boundary_iterator_begin(vector[int] simplex) nogil - Simplex_tree_boundary_iterator get_boundary_iterator_end(vector[int] simplex) nogil + pair[Simplex_tree_boundary_iterator, Simplex_tree_boundary_iterator] get_boundary_iterators(vector[int] simplex) nogil except + cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi": cdef cppclass Simplex_tree_persistence_interface "Gudhi::Persistent_cohomology_interface>": diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index 3ebae923..bc5b43f4 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -293,12 +293,14 @@ cdef class SimplexTree: :returns: The (simplices of the) boundary of a simplex :rtype: generator with tuples(simplex, filtration) """ - cdef Simplex_tree_boundary_iterator it = self.get_ptr().get_boundary_iterator_begin(simplex) - cdef Simplex_tree_boundary_iterator end = self.get_ptr().get_boundary_iterator_end(simplex) - - while it != end: - yield self.get_ptr().get_simplex_and_filtration(dereference(it)) - preincrement(it) + cdef pair[Simplex_tree_boundary_iterator, Simplex_tree_boundary_iterator] it = self.get_ptr().get_boundary_iterators(simplex) + + try: + while it.first != it.second: + yield self.get_ptr().get_simplex_and_filtration(dereference(it.first)) + preincrement(it.first) + except RuntimeError: + print("simplex not found - cannot find boundaries") def remove_maximal_simplex(self, simplex): diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h index c4f18eeb..2c695d1b 100644 --- a/src/python/include/Simplex_tree_interface.h +++ b/src/python/include/Simplex_tree_interface.h @@ -190,14 +190,12 @@ class Simplex_tree_interface : public Simplex_tree { return Base::skeleton_simplex_range(dimension).end(); } - Boundary_simplex_iterator get_boundary_iterator_begin(const Simplex& simplex) { + std::pair get_boundary_iterators(const Simplex& simplex) { + auto bd_sh = Base::find(simplex); + if (bd_sh == Base::null_simplex()) + throw std::runtime_error("simplex not found - cannot find boundaries"); // this specific case works because the range is just a pair of iterators - won't work if range was a vector - return Base::boundary_simplex_range(Base::find(simplex)).begin(); - } - - Boundary_simplex_iterator get_boundary_iterator_end(const Simplex& simplex) { - // this specific case works because the range is just a pair of iterators - won't work if range was a vector - return Base::boundary_simplex_range(Base::find(simplex)).end(); + return std::make_pair(Base::boundary_simplex_range(bd_sh).begin(), Base::boundary_simplex_range(bd_sh).end()); } }; diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 036e88df..828400fb 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -350,3 +350,12 @@ def test_boundaries_iterator(): assert list(st.get_boundaries([1, 2, 3])) == [([1, 2], 1.0), ([1, 3], 1.0), ([2, 3], 1.0)] assert list(st.get_boundaries([2, 3, 4])) == [([2, 3], 1.0), ([2, 4], 2.0), ([3, 4], 2.0)] assert list(st.get_boundaries([2])) == [] + + with pytest.raises(RuntimeError): + list(st.get_boundaries([])) + + with pytest.raises(RuntimeError): + list(st.get_boundaries([0, 4])) # (0, 4) does not exist + + with pytest.raises(RuntimeError): + list(st.get_boundaries([6])) # (6) does not exist -- cgit v1.2.3 From d9ae2cb822631d68e4ef1cec7eed0124080c0020 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Mon, 24 Aug 2020 09:17:23 +0200 Subject: code review: call boundary_simplex_range only once --- src/python/include/Simplex_tree_interface.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h index ab3f13f1..baff3850 100644 --- a/src/python/include/Simplex_tree_interface.h +++ b/src/python/include/Simplex_tree_interface.h @@ -226,7 +226,8 @@ class Simplex_tree_interface : public Simplex_tree { if (bd_sh == Base::null_simplex()) throw std::runtime_error("simplex not found - cannot find boundaries"); // this specific case works because the range is just a pair of iterators - won't work if range was a vector - return std::make_pair(Base::boundary_simplex_range(bd_sh).begin(), Base::boundary_simplex_range(bd_sh).end()); + auto boundary_srange = Base::boundary_simplex_range(bd_sh); + return std::make_pair(boundary_srange.begin(), boundary_srange.end()); } }; -- cgit v1.2.3 From 4186971033ee43821905cac53791bf074751d3af Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Mon, 28 Sep 2020 09:15:52 +0200 Subject: code review: Let the exception propagate --- src/python/gudhi/simplex_tree.pyx | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index 5250937b..5043c621 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -295,13 +295,9 @@ cdef class SimplexTree: """ cdef pair[Simplex_tree_boundary_iterator, Simplex_tree_boundary_iterator] it = self.get_ptr().get_boundary_iterators(simplex) - try: - while it.first != it.second: - yield self.get_ptr().get_simplex_and_filtration(dereference(it.first)) - preincrement(it.first) - except RuntimeError: - print("simplex not found - cannot find boundaries") - + while it.first != it.second: + yield self.get_ptr().get_simplex_and_filtration(dereference(it.first)) + preincrement(it.first) def remove_maximal_simplex(self, simplex): """This function removes a given maximal N-simplex from the simplicial -- cgit v1.2.3 From e7b7947adf13ec1dcb8c126a4373fa29baaecb63 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Tue, 29 Sep 2020 13:23:56 +0200 Subject: Added tests for wasserstein distance with tensorflow --- .../modules/GUDHI_third_party_libraries.cmake | 1 + src/python/CMakeLists.txt | 8 +++++++ src/python/doc/installation.rst | 5 +++++ src/python/test/test_wasserstein_with_tensors.py | 25 ++++++++++++++++++++++ 4 files changed, 39 insertions(+) create mode 100755 src/python/test/test_wasserstein_with_tensors.py diff --git a/src/cmake/modules/GUDHI_third_party_libraries.cmake b/src/cmake/modules/GUDHI_third_party_libraries.cmake index 1fbc4244..9dadac4f 100644 --- a/src/cmake/modules/GUDHI_third_party_libraries.cmake +++ b/src/cmake/modules/GUDHI_third_party_libraries.cmake @@ -155,6 +155,7 @@ if( PYTHONINTERP_FOUND ) find_python_module("pykeops") find_python_module("eagerpy") find_python_module_no_version("hnswlib") + find_python_module("tensorflow") endif() if(NOT GUDHI_PYTHON_PATH) diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 4f26481e..cc71503f 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -103,6 +103,9 @@ if(PYTHONINTERP_FOUND) if(EAGERPY_FOUND) add_gudhi_debug_info("EagerPy version ${EAGERPY_VERSION}") endif() + if(TENSORFLOW_FOUND) + add_gudhi_debug_info("TensorFlow version ${TENSORFLOW_VERSION}") + endif() set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DBOOST_RESULT_OF_USE_DECLTYPE', ") set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DBOOST_ALL_NO_LIB', ") @@ -501,6 +504,11 @@ if(PYTHONINTERP_FOUND) endif() add_gudhi_py_test(test_wasserstein_barycenter) endif() + if(OT_FOUND) + if(TENSORFLOW_FOUND AND EAGERPY_FOUND) + add_gudhi_py_test(test_wasserstein_with_tensors) + endif() + endif() # Representations if(SKLEARN_FOUND AND MATPLOTLIB_FOUND) diff --git a/src/python/doc/installation.rst b/src/python/doc/installation.rst index 2f161d66..66efe45a 100644 --- a/src/python/doc/installation.rst +++ b/src/python/doc/installation.rst @@ -394,6 +394,11 @@ mathematics, science, and engineering. :class:`~gudhi.point_cloud.knn.KNearestNeighbors` can use the Python package `SciPy `_ as a backend if explicitly requested. +TensorFlow +---------- + +`TensorFlow `_ is currently only used in some automatic differentiation tests. + Bug reports and contributions ***************************** diff --git a/src/python/test/test_wasserstein_with_tensors.py b/src/python/test/test_wasserstein_with_tensors.py new file mode 100755 index 00000000..8957705d --- /dev/null +++ b/src/python/test/test_wasserstein_with_tensors.py @@ -0,0 +1,25 @@ +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Mathieu Carriere + + Copyright (C) 2020 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +from gudhi.wasserstein import wasserstein_distance as pot +import numpy as np + +def test_wasserstein_distance_grad_tensorflow(): + import tensorflow as tf + + with tf.GradientTape() as tape: + diag4 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[0., 10.]]), trainable=True)) + diag5 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[1., 11.], [3., 4.]]), trainable=True)) + dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True) + assert dist45 == 3. + + grads = tape.gradient(dist45, [diag4, diag5]) + assert np.array_equal(grads[0].values, [[-1., -1.]]) + assert np.array_equal(grads[1].values, [[1., 1.], [-1., 1.]]) \ No newline at end of file -- cgit v1.2.3 From f0beb329f5a1767e4e0a0575ef3e078bf4563a44 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Mon, 5 Oct 2020 11:12:44 +0200 Subject: code review: move test_wasserstein_distance_grad from test_wasserstein_distance.py to test_wasserstein_with_tensors.py --- src/python/CMakeLists.txt | 5 +++-- src/python/test/test_wasserstein_distance.py | 24 ---------------------- src/python/test/test_wasserstein_with_tensors.py | 26 ++++++++++++++++++++++-- 3 files changed, 27 insertions(+), 28 deletions(-) diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index cc71503f..c09996fe 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -499,13 +499,14 @@ if(PYTHONINTERP_FOUND) # Wasserstein if(OT_FOUND AND PYBIND11_FOUND) - if(TORCH_FOUND AND EAGERPY_FOUND) + # EagerPy dependency because of enable_autodiff=True + if(EAGERPY_FOUND) add_gudhi_py_test(test_wasserstein_distance) endif() add_gudhi_py_test(test_wasserstein_barycenter) endif() if(OT_FOUND) - if(TENSORFLOW_FOUND AND EAGERPY_FOUND) + if(TORCH_FOUND AND TENSORFLOW_FOUND AND EAGERPY_FOUND) add_gudhi_py_test(test_wasserstein_with_tensors) endif() endif() diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 90d26809..e3b521d6 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -97,27 +97,3 @@ def test_wasserstein_distance_pot(): def test_wasserstein_distance_hera(): _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False) _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False) - -def test_wasserstein_distance_grad(): - import torch - - diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True) - diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) - diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) - assert diag1.grad is None and diag2.grad is None and diag3.grad is None - dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True) - dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True) - dist12.backward() - dist30.backward() - assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any() - diag4 = torch.tensor([[0., 10.]], requires_grad=True) - diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True) - dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True) - assert dist45 == 3. - dist45.backward() - assert np.array_equal(diag4.grad, [[-1., -1.]]) - assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]]) - diag6 = torch.tensor([[5., 10.]], requires_grad=True) - pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward() - # https://github.com/jonasrauber/eagerpy/issues/6 - # assert np.array_equal(diag6.grad, [[0., 0.]]) diff --git a/src/python/test/test_wasserstein_with_tensors.py b/src/python/test/test_wasserstein_with_tensors.py index 8957705d..e3f1411a 100755 --- a/src/python/test/test_wasserstein_with_tensors.py +++ b/src/python/test/test_wasserstein_with_tensors.py @@ -10,10 +10,32 @@ from gudhi.wasserstein import wasserstein_distance as pot import numpy as np +import torch +import tensorflow as tf -def test_wasserstein_distance_grad_tensorflow(): - import tensorflow as tf +def test_wasserstein_distance_grad(): + diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True) + diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) + diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) + assert diag1.grad is None and diag2.grad is None and diag3.grad is None + dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True) + dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True) + dist12.backward() + dist30.backward() + assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any() + diag4 = torch.tensor([[0., 10.]], requires_grad=True) + diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True) + dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True) + assert dist45 == 3. + dist45.backward() + assert np.array_equal(diag4.grad, [[-1., -1.]]) + assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]]) + diag6 = torch.tensor([[5., 10.]], requires_grad=True) + pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward() + # https://github.com/jonasrauber/eagerpy/issues/6 + # assert np.array_equal(diag6.grad, [[0., 0.]]) +def test_wasserstein_distance_grad_tensorflow(): with tf.GradientTape() as tape: diag4 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[0., 10.]]), trainable=True)) diag5 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[1., 11.], [3., 4.]]), trainable=True)) -- cgit v1.2.3 From c66b9126429e1ff18f9ca69b27c5f357f071a697 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Mon, 19 Oct 2020 01:05:04 +0200 Subject: Handle duplicated points --- src/Subsampling/include/gudhi/choose_n_farthest_points.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Subsampling/include/gudhi/choose_n_farthest_points.h b/src/Subsampling/include/gudhi/choose_n_farthest_points.h index 66421a69..38c3a76b 100644 --- a/src/Subsampling/include/gudhi/choose_n_farthest_points.h +++ b/src/Subsampling/include/gudhi/choose_n_farthest_points.h @@ -111,6 +111,8 @@ void choose_n_farthest_points(Kernel const &k, curr_max_dist = dist_to_L[i]; curr_max_w = i; } + // If all that remains are duplicates of points already taken, stop. + if (curr_max_dist == 0) break; } } -- cgit v1.2.3 From dda7885005c343601c6630796eb56bdcf91a559f Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Thu, 22 Oct 2020 22:23:28 +0200 Subject: Document the change It would be possible to emit the duplicate points instead of stopping, but the current implementation makes that inconvenient. --- .../include/gudhi/choose_n_farthest_points.h | 3 ++- src/python/gudhi/subsampling.pyx | 21 +++++++++++---------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/Subsampling/include/gudhi/choose_n_farthest_points.h b/src/Subsampling/include/gudhi/choose_n_farthest_points.h index 38c3a76b..0e13fc5a 100644 --- a/src/Subsampling/include/gudhi/choose_n_farthest_points.h +++ b/src/Subsampling/include/gudhi/choose_n_farthest_points.h @@ -48,7 +48,8 @@ enum : std::size_t { * \tparam PointOutputIterator Output iterator whose value type is Kernel::Point_d. * \tparam DistanceOutputIterator Output iterator for distances. * \details It chooses `final_size` points from a random access range - * `input_pts` and outputs them in the output iterator `output_it`. It also + * `input_pts` (or the number of distinct points if `final_size` is larger) + * and outputs them in the output iterator `output_it`. It also * outputs the distance from each of those points to the set of previous * points in `dist_it`. * @param[in] k A kernel object. diff --git a/src/python/gudhi/subsampling.pyx b/src/python/gudhi/subsampling.pyx index f77c6f75..b11d07e5 100644 --- a/src/python/gudhi/subsampling.pyx +++ b/src/python/gudhi/subsampling.pyx @@ -33,7 +33,7 @@ def choose_n_farthest_points(points=None, off_file='', nb_points=0, starting_poi The iteration starts with the landmark `starting point`. :param points: The input point set. - :type points: Iterable[Iterable[float]]. + :type points: Iterable[Iterable[float]] Or @@ -42,14 +42,15 @@ def choose_n_farthest_points(points=None, off_file='', nb_points=0, starting_poi And in both cases - :param nb_points: Number of points of the subsample. - :type nb_points: unsigned. + :param nb_points: Number of points of the subsample (the subsample may be \ + smaller if there are fewer than nb_points distinct input points) + :type nb_points: int :param starting_point: The iteration starts with the landmark `starting \ - point`,which is the index of the point to start with. If not set, this \ + point`, which is the index of the point to start with. If not set, this \ index is chosen randomly. - :type starting_point: unsigned. + :type starting_point: int :returns: The subsample point set. - :rtype: List[List[float]]. + :rtype: List[List[float]] """ if off_file: if os.path.isfile(off_file): @@ -76,7 +77,7 @@ def pick_n_random_points(points=None, off_file='', nb_points=0): """Subsample a point set by picking random vertices. :param points: The input point set. - :type points: Iterable[Iterable[float]]. + :type points: Iterable[Iterable[float]] Or @@ -86,7 +87,7 @@ def pick_n_random_points(points=None, off_file='', nb_points=0): And in both cases :param nb_points: Number of points of the subsample. - :type nb_points: unsigned. + :type nb_points: int :returns: The subsample point set. :rtype: List[List[float]] """ @@ -107,7 +108,7 @@ def sparsify_point_set(points=None, off_file='', min_squared_dist=0.0): between any two points is greater than or equal to min_squared_dist. :param points: The input point set. - :type points: Iterable[Iterable[float]]. + :type points: Iterable[Iterable[float]] Or @@ -118,7 +119,7 @@ def sparsify_point_set(points=None, off_file='', min_squared_dist=0.0): :param min_squared_dist: Minimum squared distance separating the output \ points. - :type min_squared_dist: float. + :type min_squared_dist: float :returns: The subsample point set. :rtype: List[List[float]] """ -- cgit v1.2.3 From 8aea376ed0b3c9066fb7e649f1cd66ffbed99a8d Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Thu, 22 Oct 2020 22:27:10 +0200 Subject: Simplify strange iterator use the syntax with [] is already used a few lines above --- src/Subsampling/include/gudhi/choose_n_farthest_points.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Subsampling/include/gudhi/choose_n_farthest_points.h b/src/Subsampling/include/gudhi/choose_n_farthest_points.h index 0e13fc5a..b70af8a0 100644 --- a/src/Subsampling/include/gudhi/choose_n_farthest_points.h +++ b/src/Subsampling/include/gudhi/choose_n_farthest_points.h @@ -100,7 +100,7 @@ void choose_n_farthest_points(Kernel const &k, *dist_it++ = dist_to_L[curr_max_w]; std::size_t i = 0; for (auto&& p : input_pts) { - double curr_dist = sqdist(p, *(std::begin(input_pts) + curr_max_w)); + double curr_dist = sqdist(p, input_pts[curr_max_w]); if (curr_dist < dist_to_L[i]) dist_to_L[i] = curr_dist; ++i; -- cgit v1.2.3 From 3ae0bc89f6ef853c1c52fc609bfe08097d3594db Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Thu, 22 Oct 2020 22:49:23 +0200 Subject: test with duplicated point --- src/Subsampling/test/test_choose_n_farthest_points.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/Subsampling/test/test_choose_n_farthest_points.cpp b/src/Subsampling/test/test_choose_n_farthest_points.cpp index 5c4bd4cb..08b82d61 100644 --- a/src/Subsampling/test/test_choose_n_farthest_points.cpp +++ b/src/Subsampling/test/test_choose_n_farthest_points.cpp @@ -93,7 +93,15 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(test_choose_farthest_point_limits, Kernel, list_of std::vector point2({1.0, 0.0, 0.0, 0.0}); points.push_back(Point_d(point2.begin(), point2.end())); - // Choose all farthest points in a one point cloud + // Choose all farthest points among 2 points + Gudhi::subsampling::choose_n_farthest_points(k, points, -1, -1, std::back_inserter(landmarks), std::back_inserter(distances)); + BOOST_CHECK(landmarks.size() == 2 && distances.size() == 2); + BOOST_CHECK(distances[0] == std::numeric_limits::infinity()); + BOOST_CHECK(distances[1] == 1); + landmarks.clear(); distances.clear(); + + // Ignore duplicated points + points.push_back(Point_d(point.begin(), point.end())); Gudhi::subsampling::choose_n_farthest_points(k, points, -1, -1, std::back_inserter(landmarks), std::back_inserter(distances)); BOOST_CHECK(landmarks.size() == 2 && distances.size() == 2); BOOST_CHECK(distances[0] == std::numeric_limits::infinity()); -- cgit v1.2.3 From 705aa3b7ddc0a2bbbcc31c4b45e19792bd4ce9a5 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Wed, 28 Oct 2020 00:04:37 +0100 Subject: emplace_back --- src/Subsampling/test/test_choose_n_farthest_points.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Subsampling/test/test_choose_n_farthest_points.cpp b/src/Subsampling/test/test_choose_n_farthest_points.cpp index 08b82d61..b318d58e 100644 --- a/src/Subsampling/test/test_choose_n_farthest_points.cpp +++ b/src/Subsampling/test/test_choose_n_farthest_points.cpp @@ -39,7 +39,7 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(test_choose_farthest_point, Kernel, list_of_tested for (FT k = 0; k < 5; k += 1.0) for (FT l = 0; l < 5; l += 1.0) { std::vector point({i, j, k, l}); - points.push_back(Point_d(point.begin(), point.end())); + points.emplace_back(point.begin(), point.end()); } landmarks.clear(); @@ -75,7 +75,7 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(test_choose_farthest_point_limits, Kernel, list_of landmarks.clear(); distances.clear(); std::vector point({0.0, 0.0, 0.0, 0.0}); - points.push_back(Point_d(point.begin(), point.end())); + points.emplace_back(point.begin(), point.end()); // Choose -1 farthest points in a one point cloud Gudhi::subsampling::choose_n_farthest_points(k, points, -1, -1, std::back_inserter(landmarks), std::back_inserter(distances)); BOOST_CHECK(landmarks.size() == 1 && distances.size() == 1); @@ -92,7 +92,7 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(test_choose_farthest_point_limits, Kernel, list_of landmarks.clear(); distances.clear(); std::vector point2({1.0, 0.0, 0.0, 0.0}); - points.push_back(Point_d(point2.begin(), point2.end())); + points.emplace_back(point2.begin(), point2.end()); // Choose all farthest points among 2 points Gudhi::subsampling::choose_n_farthest_points(k, points, -1, -1, std::back_inserter(landmarks), std::back_inserter(distances)); BOOST_CHECK(landmarks.size() == 2 && distances.size() == 2); @@ -101,7 +101,7 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(test_choose_farthest_point_limits, Kernel, list_of landmarks.clear(); distances.clear(); // Ignore duplicated points - points.push_back(Point_d(point.begin(), point.end())); + points.emplace_back(point.begin(), point.end()); Gudhi::subsampling::choose_n_farthest_points(k, points, -1, -1, std::back_inserter(landmarks), std::back_inserter(distances)); BOOST_CHECK(landmarks.size() == 2 && distances.size() == 2); BOOST_CHECK(distances[0] == std::numeric_limits::infinity()); -- cgit v1.2.3