From af49fdd761bf1eccb5fdca760a99e2e250895f64 Mon Sep 17 00:00:00 2001 From: martinroyer-buntu Date: Fri, 3 Jul 2020 10:58:54 +0200 Subject: dummy test for code coverage --- src/python/test/test_representations.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) (limited to 'src/python/test') diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py index 589cee00..6a09be48 100755 --- a/src/python/test/test_representations.py +++ b/src/python/test/test_representations.py @@ -4,6 +4,8 @@ import matplotlib.pyplot as plt import numpy as np import pytest +from sklearn.cluster import KMeans + def test_representations_examples(): # Disable graphics for testing purposes @@ -15,6 +17,7 @@ def test_representations_examples(): return None +from gudhi.representations.vector_methods import Atol from gudhi.representations.metrics import * from gudhi.representations.kernel_methods import * @@ -41,3 +44,16 @@ def test_multiple(): d2 = WassersteinDistance(order=2, internal_p=2, n_jobs=4).fit(l2).transform(l1) print(d1.shape, d2.shape) assert d1 == pytest.approx(d2, rel=.02) + + +def test_dummy_atol(): + a = np.array([[1, 2, 4], [1, 4, 0], [1, 0, 4]]) + b = np.array([[4, 2, 0], [4, 4, 0], [4, 0, 2]]) + c = np.array([[3, 2, -1], [1, 2, -1]]) + + for weighting_method in ["cloud", "iidproba"]: + for contrast in ["gaussian", "laplacian", "indicator"]: + atol_vectoriser = Atol(quantiser=KMeans(n_clusters=1, random_state=202006), weighting_method=weighting_method, contrast=contrast) + atol_vectoriser(a) + atol_vectoriser.transform(X=[a, b, c]) + -- cgit v1.2.3 From 96eb09e4f034fd71f5674f75e5e4584a7402b218 Mon Sep 17 00:00:00 2001 From: martinroyer-buntu Date: Fri, 3 Jul 2020 13:43:58 +0200 Subject: missing fit in dummy test --- src/python/test/test_representations.py | 1 + 1 file changed, 1 insertion(+) (limited to 'src/python/test') diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py index 6a09be48..e5c211a0 100755 --- a/src/python/test/test_representations.py +++ b/src/python/test/test_representations.py @@ -54,6 +54,7 @@ def test_dummy_atol(): for weighting_method in ["cloud", "iidproba"]: for contrast in ["gaussian", "laplacian", "indicator"]: atol_vectoriser = Atol(quantiser=KMeans(n_clusters=1, random_state=202006), weighting_method=weighting_method, contrast=contrast) + atol_vectoriser.fit([a, b, c]) atol_vectoriser(a) atol_vectoriser.transform(X=[a, b, c]) -- cgit v1.2.3 From 88a36ffad6c11279990c1c96df32b95c1f6f526c Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Fri, 3 Jul 2020 13:57:49 +0200 Subject: A fix proposal for boudaries of a simplex python version --- .../include/gudhi/Simplex_tree/Simplex_tree_iterators.h | 5 +++++ src/python/gudhi/simplex_tree.pxd | 8 ++++++++ src/python/gudhi/simplex_tree.pyx | 16 ++++++++++++++++ src/python/include/Simplex_tree_interface.h | 11 +++++++++++ src/python/test/test_simplex_tree.py | 9 +++++++++ 5 files changed, 49 insertions(+) (limited to 'src/python/test') diff --git a/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h b/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h index 9007b6bd..9c864454 100644 --- a/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h +++ b/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h @@ -85,6 +85,11 @@ class Simplex_tree_boundary_simplex_iterator : public boost::iterator_facade< typedef typename SimplexTree::Vertex_handle Vertex_handle; typedef typename SimplexTree::Siblings Siblings; + Simplex_tree_boundary_simplex_iterator() + : sib_(nullptr), + st_(nullptr) { + } + // any end() iterator explicit Simplex_tree_boundary_simplex_iterator(SimplexTree * st) : last_(st->null_vertex()), diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd index e748ac40..a64599e7 100644 --- a/src/python/gudhi/simplex_tree.pxd +++ b/src/python/gudhi/simplex_tree.pxd @@ -36,6 +36,12 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": Simplex_tree_skeleton_iterator operator++() nogil bint operator!=(Simplex_tree_skeleton_iterator) nogil + cdef cppclass Simplex_tree_boundary_iterator "Gudhi::Simplex_tree_interface::Boundary_simplex_iterator": + Simplex_tree_boundary_iterator() nogil + Simplex_tree_simplex_handle& operator*() nogil + Simplex_tree_boundary_iterator operator++() nogil + bint operator!=(Simplex_tree_boundary_iterator) nogil + cdef cppclass Simplex_tree_interface_full_featured "Gudhi::Simplex_tree_interface": Simplex_tree() nogil @@ -65,6 +71,8 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": vector[Simplex_tree_simplex_handle].const_iterator get_filtration_iterator_end() nogil Simplex_tree_skeleton_iterator get_skeleton_iterator_begin(int dimension) nogil Simplex_tree_skeleton_iterator get_skeleton_iterator_end(int dimension) nogil + Simplex_tree_boundary_iterator get_boundary_iterator_begin(vector[int] simplex) nogil + Simplex_tree_boundary_iterator get_boundary_iterator_end(vector[int] simplex) nogil cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi": cdef cppclass Simplex_tree_persistence_interface "Gudhi::Persistent_cohomology_interface>": diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index 20e66d9f..da445a12 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -285,6 +285,22 @@ cdef class SimplexTree: ct.append((v, filtered_simplex.second)) return ct + def get_boundaries(self, simplex): + """This function returns a generator with the boundaries of a given N-simplex. + + :param simplex: The N-simplex, represented by a list of vertex. + :type simplex: list of int. + :returns: The (simplices of the) boundaries of a simplex + :rtype: generator with tuples(simplex, filtration) + """ + cdef Simplex_tree_boundary_iterator it = self.get_ptr().get_boundary_iterator_begin(simplex) + cdef Simplex_tree_boundary_iterator end = self.get_ptr().get_boundary_iterator_end(simplex) + + while it != end: + yield self.get_ptr().get_simplex_and_filtration(dereference(it)) + preincrement(it) + + def remove_maximal_simplex(self, simplex): """This function removes a given maximal N-simplex from the simplicial complex. diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h index 56d7c41d..c4f18eeb 100644 --- a/src/python/include/Simplex_tree_interface.h +++ b/src/python/include/Simplex_tree_interface.h @@ -36,6 +36,7 @@ class Simplex_tree_interface : public Simplex_tree { using Skeleton_simplex_iterator = typename Base::Skeleton_simplex_iterator; using Complex_simplex_iterator = typename Base::Complex_simplex_iterator; using Extended_filtration_data = typename Base::Extended_filtration_data; + using Boundary_simplex_iterator = typename Base::Boundary_simplex_iterator; public: @@ -188,6 +189,16 @@ class Simplex_tree_interface : public Simplex_tree { // this specific case works because the range is just a pair of iterators - won't work if range was a vector return Base::skeleton_simplex_range(dimension).end(); } + + Boundary_simplex_iterator get_boundary_iterator_begin(const Simplex& simplex) { + // this specific case works because the range is just a pair of iterators - won't work if range was a vector + return Base::boundary_simplex_range(Base::find(simplex)).begin(); + } + + Boundary_simplex_iterator get_boundary_iterator_end(const Simplex& simplex) { + // this specific case works because the range is just a pair of iterators - won't work if range was a vector + return Base::boundary_simplex_range(Base::find(simplex)).end(); + } }; } // namespace Gudhi diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 2137d822..7c49cb1a 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -340,3 +340,12 @@ def test_simplices_iterator(): assert st.find(simplex[0]) == True print("filtration is: ", simplex[1]) assert st.filtration(simplex[0]) == simplex[1] + +def test_boundaries_iterator(): + st = SimplexTree() + + assert st.insert([0, 1, 2, 3], filtration=1.0) == True + assert st.insert([1, 2, 3, 4], filtration=2.0) == True + + assert list(st.get_boundaries([1, 2, 3])) == [([1, 2], 1.0), ([1, 3], 1.0), ([2, 3], 1.0)] + assert list(st.get_boundaries([2, 3, 4])) == [([2, 3], 1.0), ([2, 4], 2.0), ([3, 4], 2.0)] -- cgit v1.2.3 From 85eec1ba750d56b66e3739dc486c6205f49fb31e Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Fri, 3 Jul 2020 16:04:45 +0200 Subject: A proposal for simplex_tree reset_filtration (python & C++) --- src/Simplex_tree/include/gudhi/Simplex_tree.h | 30 +++++++++++++++ src/Simplex_tree/test/simplex_tree_unit_test.cpp | 47 ++++++++++++++++++++++++ src/python/gudhi/simplex_tree.pxd | 1 + src/python/gudhi/simplex_tree.pyx | 10 +++++ src/python/test/test_simplex_tree.py | 22 +++++++++++ 5 files changed, 110 insertions(+) (limited to 'src/python/test') diff --git a/src/Simplex_tree/include/gudhi/Simplex_tree.h b/src/Simplex_tree/include/gudhi/Simplex_tree.h index 889dbd00..adc8e801 100644 --- a/src/Simplex_tree/include/gudhi/Simplex_tree.h +++ b/src/Simplex_tree/include/gudhi/Simplex_tree.h @@ -1667,6 +1667,36 @@ class Simplex_tree { return sh; // None of its faces has the same filtration. } + public: + /** \brief This function resets filtration value until a given dimension. + * @param[in] filt_value The new filtration value. + * @param[in] max_dim The maximal dimension. + */ + void reset_filtration(Filtration_value filt_value, int max_dim) { + for (auto& simplex : root_.members()) { + simplex.second.assign_filtration(filt_value); + if (has_children(&simplex) && max_dim > 0) { + rec_reset_filtration(simplex.second.children(), filt_value, (max_dim - 1)); + } + } + clear_filtration(); // Drop the cache. + } + + private: + /** \brief Recursively resets filtration value until a given dimension. + * @param[in] sib Siblings to be parsed. + * @param[in] filt_value The new filtration value. + * @param[in] max_dim The maximal dimension. + */ + void rec_reset_filtration(Siblings * sib, Filtration_value filt_value, int max_dim) { + for (auto& simplex : sib->members()) { + simplex.second.assign_filtration(filt_value); + if (has_children(&simplex) && max_dim > 0) { + rec_reset_filtration(simplex.second.children(), filt_value, (max_dim - 1)); + } + } + } + private: Vertex_handle null_vertex_; /** \brief Total number of simplices in the complex, without the empty simplex.*/ diff --git a/src/Simplex_tree/test/simplex_tree_unit_test.cpp b/src/Simplex_tree/test/simplex_tree_unit_test.cpp index 9b5fa8fe..9780f5b0 100644 --- a/src/Simplex_tree/test/simplex_tree_unit_test.cpp +++ b/src/Simplex_tree/test/simplex_tree_unit_test.cpp @@ -940,3 +940,50 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(generators, typeST, list_of_tested_variants) { BOOST_CHECK(st.edge_with_same_filtration(st.find({1,5}))==st.find({1,5})); } } + +BOOST_AUTO_TEST_CASE_TEMPLATE(simplex_tree_reset_filtration, typeST, list_of_tested_variants) { + std::clog << "********************************************************************" << std::endl; + std::clog << "TEST RESET FILTRATION" << std::endl; + typeST st; + + st.insert_simplex_and_subfaces({2, 1, 0}, 3.); + st.insert_simplex_and_subfaces({3, 0}, 2.); + st.insert_simplex_and_subfaces({3, 4, 5}, 3.); + st.insert_simplex_and_subfaces({0, 1, 6, 7}, 4.); + + /* Inserted simplex: */ + /* 1 6 */ + /* o---o */ + /* /X\7/ */ + /* o---o---o---o */ + /* 2 0 3\X/4 */ + /* o */ + /* 5 */ + + for (auto f_simplex : st.skeleton_simplex_range(3)) { + std::clog << "vertex = ("; + for (auto vertex : st.simplex_vertex_range(f_simplex)) { + std::clog << vertex << ","; + } + std::clog << ") - filtration =" << st.filtration(f_simplex) << std::endl; + BOOST_CHECK(st.filtration(f_simplex) >= 2.); + } + + // dimension until 5 even if simplex tree is of dimension 3 to test the limits + for(int dimension = 0; dimension < 6; dimension ++) { + st.reset_filtration(0., dimension); + for (auto f_simplex : st.skeleton_simplex_range(3)) { + std::clog << "vertex = ("; + for (auto vertex : st.simplex_vertex_range(f_simplex)) { + std::clog << vertex << ","; + } + std::clog << ") - filtration =" << st.filtration(f_simplex) << std::endl; + if (st.dimension(f_simplex) > dimension) + BOOST_CHECK(st.filtration(f_simplex) >= 1.); + else + BOOST_CHECK(st.filtration(f_simplex) == 0.); + } + } + +} + diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd index e748ac40..12c2065e 100644 --- a/src/python/gudhi/simplex_tree.pxd +++ b/src/python/gudhi/simplex_tree.pxd @@ -57,6 +57,7 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": bool make_filtration_non_decreasing() nogil void compute_extended_filtration() nogil vector[vector[pair[int, pair[double, double]]]] compute_extended_persistence_subdiagrams(vector[pair[int, pair[double, double]]] dgm, double min_persistence) nogil + void reset_filtration(double filtration, int dimension) nogil # Iterators over Simplex tree pair[vector[int], double] get_simplex_and_filtration(Simplex_tree_simplex_handle f_simplex) nogil Simplex_tree_simplices_iterator get_simplices_iterator_begin() nogil diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index 20e66d9f..41b06116 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -358,6 +358,16 @@ cdef class SimplexTree: """ return self.get_ptr().make_filtration_non_decreasing() + def reset_filtration(self, filtration, max_dim): + """This function resets filtration value until a given dimension. + + :param filtration: New threshold value. + :type filtration: float. + :param max_dim: The maximal dimension. + :type max_dim: int. + """ + self.get_ptr().reset_filtration(filtration, max_dim) + def extend_filtration(self): """ Extend filtration for computing extended persistence. This function only uses the filtration values at the 0-dimensional simplices, and computes the extended persistence diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 2137d822..1ca84c10 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -340,3 +340,25 @@ def test_simplices_iterator(): assert st.find(simplex[0]) == True print("filtration is: ", simplex[1]) assert st.filtration(simplex[0]) == simplex[1] + +def test_reset_filtration(): + st = SimplexTree() + + assert st.insert([0, 1, 2], 3.) == True + assert st.insert([0, 3], 2.) == True + assert st.insert([3, 4, 5], 3.) == True + assert st.insert([0, 1, 6, 7], 4.) == True + + for simplex in st.get_simplices(): + assert st.filtration(simplex[0]) >= 0. + + # dimension until 5 even if simplex tree is of dimension 3 to test the limits + for dimension in range(0, 6): + st.reset_filtration(0., dimension) + for simplex in st.get_skeleton(3): + print(simplex) + if len(simplex[0]) > (dimension + 1): + assert st.filtration(simplex[0]) >= 1. + else: + assert st.filtration(simplex[0]) == 0. + -- cgit v1.2.3 From 76a61bcd3279a98bd84856b011869a0be2ba99cd Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Thu, 30 Jul 2020 12:36:16 +0200 Subject: collapse edges for python simplex tree --- .../example/rips_complex_edge_collapse_example.py | 65 ++++++++++++++++++++++ src/python/gudhi/simplex_tree.pxd | 1 + src/python/gudhi/simplex_tree.pyx | 63 +++++++++++++-------- src/python/include/Simplex_tree_interface.h | 29 ++++++++++ src/python/test/test_simplex_tree.py | 16 ++++++ 5 files changed, 151 insertions(+), 23 deletions(-) create mode 100755 src/python/example/rips_complex_edge_collapse_example.py (limited to 'src/python/test') diff --git a/src/python/example/rips_complex_edge_collapse_example.py b/src/python/example/rips_complex_edge_collapse_example.py new file mode 100755 index 00000000..e352c155 --- /dev/null +++ b/src/python/example/rips_complex_edge_collapse_example.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python + +import gudhi +import matplotlib.pyplot as plt +import time + +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Vincent Rouvreau + + Copyright (C) 2016 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +__author__ = "Vincent Rouvreau" +__copyright__ = "Copyright (C) 2020 Inria" +__license__ = "MIT" + + +print("#####################################################################") +print("RipsComplex (only the one-skeleton) creation from tore3D_300.off file") + +off_file = gudhi.__root_source_dir__ + '/data/points/tore3D_300.off' +point_cloud = gudhi.read_points_from_off_file(off_file = off_file) +rips_complex = gudhi.RipsComplex(points=point_cloud, max_edge_length=12.0) +simplex_tree = rips_complex.create_simplex_tree(max_dimension=1) +result_str = '1. Rips complex is of dimension ' + repr(simplex_tree.dimension()) + ' - ' + \ + repr(simplex_tree.num_simplices()) + ' simplices - ' + \ + repr(simplex_tree.num_vertices()) + ' vertices.' +print(result_str) + +# Expansion of this one-skeleton would require a lot of memory. Let's collapse it +start = time.process_time() +simplex_tree.collapse_edges() +simplex_tree.expansion(3) +diag = simplex_tree.persistence() +print("Collapse, expansion and persistence computation took ", time.process_time() - start, " sec.") +result_str = '2. Rips complex is of dimension ' + repr(simplex_tree.dimension()) + ' - ' + \ + repr(simplex_tree.num_simplices()) + ' simplices - ' + \ + repr(simplex_tree.num_vertices()) + ' vertices.' +print(result_str) + +# Use subplots to display diagram and density side by side +fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5)) +gudhi.plot_persistence_diagram(diag, axes=axes[0]) +axes[0].set_title("Persistence after 1 collapse") + +# Collapse can be performed several times. Let's collapse it 3 times +start = time.process_time() +simplex_tree.collapse_edges(nb_iterations = 3) +simplex_tree.expansion(3) +diag = simplex_tree.persistence() +print("Collapse, expansion and persistence computation took ", time.process_time() - start, " sec.") +result_str = '3. Rips complex is of dimension ' + repr(simplex_tree.dimension()) + ' - ' + \ + repr(simplex_tree.num_simplices()) + ' simplices - ' + \ + repr(simplex_tree.num_vertices()) + ' vertices.' +print(result_str) + +gudhi.plot_persistence_diagram(diag, axes=axes[1]) +axes[1].set_title("Persistence after 3 more collapses") + +# Plot the 2 persistence diagrams side to side to check the persistence is the same +plt.show() diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd index e748ac40..75e94e0b 100644 --- a/src/python/gudhi/simplex_tree.pxd +++ b/src/python/gudhi/simplex_tree.pxd @@ -57,6 +57,7 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": bool make_filtration_non_decreasing() nogil void compute_extended_filtration() nogil vector[vector[pair[int, pair[double, double]]]] compute_extended_persistence_subdiagrams(vector[pair[int, pair[double, double]]] dgm, double min_persistence) nogil + Simplex_tree_interface_full_featured* collapse_edges(int nb_collapse_iteration) nogil # Iterators over Simplex tree pair[vector[int], double] get_simplex_and_filtration(Simplex_tree_simplex_handle f_simplex) nogil Simplex_tree_simplices_iterator get_simplices_iterator_begin() nogil diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index 20e66d9f..236435a7 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -69,7 +69,7 @@ cdef class SimplexTree: this simplicial complex, or +infinity if it is not in the complex. :param simplex: The N-simplex, represented by a list of vertex. - :type simplex: list of int. + :type simplex: list of int :returns: The simplicial complex filtration value. :rtype: float """ @@ -80,7 +80,7 @@ cdef class SimplexTree: given N-simplex. :param simplex: The N-simplex, represented by a list of vertex. - :type simplex: list of int. + :type simplex: list of int :param filtration: The new filtration value. :type filtration: float @@ -153,7 +153,7 @@ cdef class SimplexTree: """This function sets the dimension of the simplicial complex. :param dimension: The new dimension value. - :type dimension: int. + :type dimension: int .. note:: @@ -172,7 +172,7 @@ cdef class SimplexTree: complex or not. :param simplex: The N-simplex to find, represented by a list of vertex. - :type simplex: list of int. + :type simplex: list of int :returns: true if the simplex was found, false otherwise. :rtype: bool """ @@ -186,9 +186,9 @@ cdef class SimplexTree: :param simplex: The N-simplex to insert, represented by a list of vertex. - :type simplex: list of int. + :type simplex: list of int :param filtration: The filtration value of the simplex. - :type filtration: float. + :type filtration: float :returns: true if the simplex was not yet in the complex, false otherwise (whatever its original filtration value). :rtype: bool @@ -228,7 +228,7 @@ cdef class SimplexTree: """This function returns a generator with the (simplices of the) skeleton of a maximum given dimension. :param dimension: The skeleton dimension value. - :type dimension: int. + :type dimension: int :returns: The (simplices of the) skeleton of a maximum dimension. :rtype: generator with tuples(simplex, filtration) """ @@ -243,7 +243,7 @@ cdef class SimplexTree: """This function returns the star of a given N-simplex. :param simplex: The N-simplex, represented by a list of vertex. - :type simplex: list of int. + :type simplex: list of int :returns: The (simplices of the) star of a simplex. :rtype: list of tuples(simplex, filtration) """ @@ -265,10 +265,10 @@ cdef class SimplexTree: given codimension. :param simplex: The N-simplex, represented by a list of vertex. - :type simplex: list of int. + :type simplex: list of int :param codimension: The codimension. If codimension = 0, all cofaces are returned (equivalent of get_star function) - :type codimension: int. + :type codimension: int :returns: The (simplices of the) cofaces of a simplex :rtype: list of tuples(simplex, filtration) """ @@ -290,7 +290,7 @@ cdef class SimplexTree: complex. :param simplex: The N-simplex, represented by a list of vertex. - :type simplex: list of int. + :type simplex: list of int .. note:: @@ -308,7 +308,7 @@ cdef class SimplexTree: """Prune above filtration value given as parameter. :param filtration: Maximum threshold value. - :type filtration: float. + :type filtration: float :returns: The filtration modification information. :rtype: bool @@ -342,7 +342,7 @@ cdef class SimplexTree: 1 when calling the method. :param max_dim: The maximal dimension. - :type max_dim: int. + :type max_dim: int """ cdef int maxdim = max_dim with nogil: @@ -383,12 +383,12 @@ cdef class SimplexTree: :param homology_coeff_field: The homology coefficient field. Must be a prime number. Default value is 11. - :type homology_coeff_field: int. + :type homology_coeff_field: int :param min_persistence: The minimum persistence value (i.e., the absolute value of the difference between the persistence diagram point coordinates) to take into account (strictly greater than min_persistence). Default value is 0.0. Sets min_persistence to -1.0 to see all values. - :type min_persistence: float. + :type min_persistence: float :returns: A list of four persistence diagrams in the format described in :func:`persistence`. The first one is Ordinary, the second one is Relative, the third one is Extended+ and the fourth one is Extended-. See https://link.springer.com/article/10.1007/s10208-008-9027-z and/or section 2.2 in https://link.springer.com/article/10.1007/s10208-017-9370-z for a description of these subtypes. .. note:: @@ -415,12 +415,12 @@ cdef class SimplexTree: :param homology_coeff_field: The homology coefficient field. Must be a prime number. Default value is 11. - :type homology_coeff_field: int. + :type homology_coeff_field: int :param min_persistence: The minimum persistence value to take into account (strictly greater than min_persistence). Default value is 0.0. Set min_persistence to -1.0 to see all values. - :type min_persistence: float. + :type min_persistence: float :param persistence_dim_max: If true, the persistent homology for the maximal dimension in the complex is computed. If false, it is ignored. Default is false. @@ -438,12 +438,12 @@ cdef class SimplexTree: :param homology_coeff_field: The homology coefficient field. Must be a prime number. Default value is 11. - :type homology_coeff_field: int. + :type homology_coeff_field: int :param min_persistence: The minimum persistence value to take into account (strictly greater than min_persistence). Default value is 0.0. Sets min_persistence to -1.0 to see all values. - :type min_persistence: float. + :type min_persistence: float :param persistence_dim_max: If true, the persistent homology for the maximal dimension in the complex is computed. If false, it is ignored. Default is false. @@ -478,10 +478,10 @@ cdef class SimplexTree: :param from_value: The persistence birth limit to be added in the numbers (persistent birth <= from_value). - :type from_value: float. + :type from_value: float :param to_value: The persistence death limit to be added in the numbers (persistent death > to_value). - :type to_value: float. + :type to_value: float :returns: The persistent Betti numbers ([B0, B1, ..., Bn]). :rtype: list of int @@ -498,7 +498,7 @@ cdef class SimplexTree: complex in a specific dimension. :param dimension: The specific dimension. - :type dimension: int. + :type dimension: int :returns: The persistence intervals. :rtype: numpy array of dimension 2 @@ -527,7 +527,7 @@ cdef class SimplexTree: complex in a user given file name. :param persistence_file: Name of the file. - :type persistence_file: string. + :type persistence_file: string :note: intervals_in_dim function requires :func:`compute_persistence` @@ -581,3 +581,20 @@ cdef class SimplexTree: infinite0 = np_array(next(l)) infinites = [np_array(d).reshape(-1,2) for d in l] return (normal0, normals, infinite0, infinites) + + def collapse_edges(self, nb_iterations = 1): + """Assuming the simplex tree is a 1-skeleton graph, this function collapse edges and resets the simplex tree + from the remaining edges. + A good candidate is to build a simplex tree on top of a :class:`~gudhi.RipsComplex` of dimension 1 before + collapsing edges. + + :param nb_iterations: The number of edge collapse iterations to perform. Default is 1. + :type nb_iterations: int + """ + # Backup old pointer + cdef Simplex_tree_interface_full_featured* ptr = self.get_ptr() + # New pointer is a new collapsed simplex tree + self.thisptr = (self.get_ptr().collapse_edges(nb_iterations)) + # Delete old pointer + if ptr != NULL: + del ptr diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h index 56d7c41d..7500098d 100644 --- a/src/python/include/Simplex_tree_interface.h +++ b/src/python/include/Simplex_tree_interface.h @@ -15,10 +15,12 @@ #include #include #include +#include #include #include #include // std::pair +#include namespace Gudhi { @@ -157,6 +159,33 @@ class Simplex_tree_interface : public Simplex_tree { return new_dgm; } + Simplex_tree_interface* collapse_edges(int nb_collapse_iteration) { + using Filtered_edge = std::tuple; + std::vector edges; + for (Simplex_handle sh : Base::skeleton_simplex_range(1)) { + if (Base::dimension(sh) == 1) { + typename Base::Simplex_vertex_range rg = Base::simplex_vertex_range(sh); + std::vector rips_edge(rg.begin(), rg.end()); + edges.push_back(std::make_tuple(rips_edge[0], rips_edge[1], Base::filtration(sh))); + } + } + + std::vector remaining_edges; + for (int iteration = 0; iteration < nb_collapse_iteration; iteration++) { + remaining_edges = Gudhi::collapse::flag_complex_collapse_edges(edges); + edges = std::move(remaining_edges); + remaining_edges.clear(); + } + Simplex_tree_interface* collapsed_stree_ptr = new Simplex_tree_interface(); + for (auto remaining_edge : edges) { + collapsed_stree_ptr->insert({std::get<0>(remaining_edge)}, 0.); + collapsed_stree_ptr->insert({std::get<1>(remaining_edge)}, 0.); + collapsed_stree_ptr->insert({std::get<0>(remaining_edge), std::get<1>(remaining_edge)}, std::get<2>(remaining_edge)); + } + collapsed_stree_ptr->initialize_filtration(); + return collapsed_stree_ptr; + } + // Iterator over the simplex tree Complex_simplex_iterator get_simplices_iterator_begin() { // this specific case works because the range is just a pair of iterators - won't work if range was a vector diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 2137d822..30a8f5e0 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -340,3 +340,19 @@ def test_simplices_iterator(): assert st.find(simplex[0]) == True print("filtration is: ", simplex[1]) assert st.filtration(simplex[0]) == simplex[1] + +def test_collapse_edges(): + st = SimplexTree() + + assert st.insert([0, 1], filtration=1.0) == True + assert st.insert([1, 2], filtration=1.0) == True + assert st.insert([2, 3], filtration=1.0) == True + assert st.insert([0, 3], filtration=1.0) == True + assert st.insert([0, 2], filtration=2.0) == True + assert st.insert([1, 3], filtration=2.0) == True + + assert st.num_simplices() == 10 + + st.collapse_edges() + assert st.num_simplices() == 9 + assert st.find([1, 3]) == False -- cgit v1.2.3 From 39fba06ef758483bc237b9375413974c3bbc16e4 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Fri, 31 Jul 2020 17:34:47 +0200 Subject: code review: collapse edges should copy the 0-skeleton. A test was added --- src/python/include/Simplex_tree_interface.h | 7 +++++-- src/python/test/test_simplex_tree.py | 2 ++ 2 files changed, 7 insertions(+), 2 deletions(-) (limited to 'src/python/test') diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h index ad0f9a28..f786ad6e 100644 --- a/src/python/include/Simplex_tree_interface.h +++ b/src/python/include/Simplex_tree_interface.h @@ -176,9 +176,12 @@ class Simplex_tree_interface : public Simplex_tree { edges = Gudhi::collapse::flag_complex_collapse_edges(edges); } Simplex_tree_interface* collapsed_stree_ptr = new Simplex_tree_interface(); + // Copy the original 0-skeleton + for (Simplex_handle sh : Base::skeleton_simplex_range(0)) { + collapsed_stree_ptr->insert({*(Base::simplex_vertex_range(sh).begin())}, Base::filtration(sh)); + } + // Insert remaining edges for (auto remaining_edge : edges) { - collapsed_stree_ptr->insert({std::get<0>(remaining_edge)}, 0.); - collapsed_stree_ptr->insert({std::get<1>(remaining_edge)}, 0.); collapsed_stree_ptr->insert({std::get<0>(remaining_edge), std::get<1>(remaining_edge)}, std::get<2>(remaining_edge)); } return collapsed_stree_ptr; diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 30a8f5e0..83be0602 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -356,3 +356,5 @@ def test_collapse_edges(): st.collapse_edges() assert st.num_simplices() == 9 assert st.find([1, 3]) == False + for simplex in st.get_skeleton(0): + assert simplex[1] == 1. -- cgit v1.2.3 From 4cbe978277d8d4fd81ef91bf26f65b5d9b279cf0 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Tue, 4 Aug 2020 16:55:52 +0200 Subject: Fix python alpha complex for conda package --- src/python/test/test_alpha_complex.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) (limited to 'src/python/test') diff --git a/src/python/test/test_alpha_complex.py b/src/python/test/test_alpha_complex.py index a4ee260b..814f8289 100755 --- a/src/python/test/test_alpha_complex.py +++ b/src/python/test/test_alpha_complex.py @@ -198,8 +198,7 @@ def test_delaunay_complex(): _delaunay_complex(precision) def _3d_points_on_a_plane(precision, default_filtration_value): - alpha = gd.AlphaComplex(off_file=gd.__root_source_dir__ + '/data/points/alphacomplexdoc.off', - precision = precision) + alpha = gd.AlphaComplex(off_file='alphacomplexdoc.off', precision = precision) simplex_tree = alpha.create_simplex_tree(default_filtration_value = default_filtration_value) assert simplex_tree.dimension() == 2 @@ -207,6 +206,18 @@ def _3d_points_on_a_plane(precision, default_filtration_value): assert simplex_tree.num_simplices() == 25 def test_3d_points_on_a_plane(): + off_file = open("alphacomplexdoc.off", "w") + off_file.write("OFF \n" \ + "7 0 0 \n" \ + "1.0 1.0 0.0\n" \ + "7.0 0.0 0.0\n" \ + "4.0 6.0 0.0\n" \ + "9.0 6.0 0.0\n" \ + "0.0 14.0 0.0\n" \ + "2.0 19.0 0.0\n" \ + "9.0 17.0 0.0\n" ) + off_file.close() + for default_filtration_value in [True, False]: for precision in ['fast', 'safe', 'exact']: _3d_points_on_a_plane(precision, default_filtration_value) -- cgit v1.2.3 From 8f9c065df7f4629555ef09292c14c293891f1bdc Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Wed, 12 Aug 2020 10:03:26 +0200 Subject: code review: Add test to get boundaries from a vertex --- src/python/test/test_simplex_tree.py | 1 + 1 file changed, 1 insertion(+) (limited to 'src/python/test') diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 7c49cb1a..036e88df 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -349,3 +349,4 @@ def test_boundaries_iterator(): assert list(st.get_boundaries([1, 2, 3])) == [([1, 2], 1.0), ([1, 3], 1.0), ([2, 3], 1.0)] assert list(st.get_boundaries([2, 3, 4])) == [([2, 3], 1.0), ([2, 4], 2.0), ([3, 4], 2.0)] + assert list(st.get_boundaries([2])) == [] -- cgit v1.2.3 From 458bc2dcf5044e1d5fde5326b2be35e526abd457 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Wed, 12 Aug 2020 13:06:03 +0200 Subject: code review: boundaries uses only once find and return a pair of iterator. Exception is raised when not found. tested --- src/python/gudhi/simplex_tree.pxd | 3 +-- src/python/gudhi/simplex_tree.pyx | 14 ++++++++------ src/python/include/Simplex_tree_interface.h | 12 +++++------- src/python/test/test_simplex_tree.py | 9 +++++++++ 4 files changed, 23 insertions(+), 15 deletions(-) (limited to 'src/python/test') diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd index a64599e7..e1b03391 100644 --- a/src/python/gudhi/simplex_tree.pxd +++ b/src/python/gudhi/simplex_tree.pxd @@ -71,8 +71,7 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": vector[Simplex_tree_simplex_handle].const_iterator get_filtration_iterator_end() nogil Simplex_tree_skeleton_iterator get_skeleton_iterator_begin(int dimension) nogil Simplex_tree_skeleton_iterator get_skeleton_iterator_end(int dimension) nogil - Simplex_tree_boundary_iterator get_boundary_iterator_begin(vector[int] simplex) nogil - Simplex_tree_boundary_iterator get_boundary_iterator_end(vector[int] simplex) nogil + pair[Simplex_tree_boundary_iterator, Simplex_tree_boundary_iterator] get_boundary_iterators(vector[int] simplex) nogil except + cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi": cdef cppclass Simplex_tree_persistence_interface "Gudhi::Persistent_cohomology_interface>": diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index 3ebae923..bc5b43f4 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -293,12 +293,14 @@ cdef class SimplexTree: :returns: The (simplices of the) boundary of a simplex :rtype: generator with tuples(simplex, filtration) """ - cdef Simplex_tree_boundary_iterator it = self.get_ptr().get_boundary_iterator_begin(simplex) - cdef Simplex_tree_boundary_iterator end = self.get_ptr().get_boundary_iterator_end(simplex) - - while it != end: - yield self.get_ptr().get_simplex_and_filtration(dereference(it)) - preincrement(it) + cdef pair[Simplex_tree_boundary_iterator, Simplex_tree_boundary_iterator] it = self.get_ptr().get_boundary_iterators(simplex) + + try: + while it.first != it.second: + yield self.get_ptr().get_simplex_and_filtration(dereference(it.first)) + preincrement(it.first) + except RuntimeError: + print("simplex not found - cannot find boundaries") def remove_maximal_simplex(self, simplex): diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h index c4f18eeb..2c695d1b 100644 --- a/src/python/include/Simplex_tree_interface.h +++ b/src/python/include/Simplex_tree_interface.h @@ -190,14 +190,12 @@ class Simplex_tree_interface : public Simplex_tree { return Base::skeleton_simplex_range(dimension).end(); } - Boundary_simplex_iterator get_boundary_iterator_begin(const Simplex& simplex) { + std::pair get_boundary_iterators(const Simplex& simplex) { + auto bd_sh = Base::find(simplex); + if (bd_sh == Base::null_simplex()) + throw std::runtime_error("simplex not found - cannot find boundaries"); // this specific case works because the range is just a pair of iterators - won't work if range was a vector - return Base::boundary_simplex_range(Base::find(simplex)).begin(); - } - - Boundary_simplex_iterator get_boundary_iterator_end(const Simplex& simplex) { - // this specific case works because the range is just a pair of iterators - won't work if range was a vector - return Base::boundary_simplex_range(Base::find(simplex)).end(); + return std::make_pair(Base::boundary_simplex_range(bd_sh).begin(), Base::boundary_simplex_range(bd_sh).end()); } }; diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 036e88df..828400fb 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -350,3 +350,12 @@ def test_boundaries_iterator(): assert list(st.get_boundaries([1, 2, 3])) == [([1, 2], 1.0), ([1, 3], 1.0), ([2, 3], 1.0)] assert list(st.get_boundaries([2, 3, 4])) == [([2, 3], 1.0), ([2, 4], 2.0), ([3, 4], 2.0)] assert list(st.get_boundaries([2])) == [] + + with pytest.raises(RuntimeError): + list(st.get_boundaries([])) + + with pytest.raises(RuntimeError): + list(st.get_boundaries([0, 4])) # (0, 4) does not exist + + with pytest.raises(RuntimeError): + list(st.get_boundaries([6])) # (6) does not exist -- cgit v1.2.3 From ddb2118f0af865588d7c14b88171dc04bb27c529 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Tue, 18 Aug 2020 14:38:31 +0200 Subject: reset_filtration from a dimension (instead of 'until') --- src/Simplex_tree/include/gudhi/Simplex_tree.h | 31 ++++++++++++++---------- src/Simplex_tree/test/simplex_tree_unit_test.cpp | 14 +++++++---- src/python/gudhi/simplex_tree.pyx | 17 +++++++------ src/python/test/test_simplex_tree.py | 9 ++++--- 4 files changed, 41 insertions(+), 30 deletions(-) (limited to 'src/python/test') diff --git a/src/Simplex_tree/include/gudhi/Simplex_tree.h b/src/Simplex_tree/include/gudhi/Simplex_tree.h index adc8e801..89b4a5df 100644 --- a/src/Simplex_tree/include/gudhi/Simplex_tree.h +++ b/src/Simplex_tree/include/gudhi/Simplex_tree.h @@ -1668,31 +1668,36 @@ class Simplex_tree { } public: - /** \brief This function resets filtration value until a given dimension. + /** \brief This function resets filtration value from a given dimension. Resets all the Simplex_tree when + * `min_dim = 0`. * @param[in] filt_value The new filtration value. - * @param[in] max_dim The maximal dimension. + * @param[in] min_dim The minimal dimension. */ - void reset_filtration(Filtration_value filt_value, int max_dim) { + void reset_filtration(Filtration_value filt_value, int min_dim) { for (auto& simplex : root_.members()) { - simplex.second.assign_filtration(filt_value); - if (has_children(&simplex) && max_dim > 0) { - rec_reset_filtration(simplex.second.children(), filt_value, (max_dim - 1)); + if (min_dim <= 0) { + simplex.second.assign_filtration(filt_value); + } + if (has_children(&simplex)) { + rec_reset_filtration(simplex.second.children(), filt_value, min_dim); } } clear_filtration(); // Drop the cache. } private: - /** \brief Recursively resets filtration value until a given dimension. + /** \brief Recursively resets filtration value from a given dimension. * @param[in] sib Siblings to be parsed. * @param[in] filt_value The new filtration value. - * @param[in] max_dim The maximal dimension. + * @param[in] min_dim The maximal dimension. */ - void rec_reset_filtration(Siblings * sib, Filtration_value filt_value, int max_dim) { - for (auto& simplex : sib->members()) { - simplex.second.assign_filtration(filt_value); - if (has_children(&simplex) && max_dim > 0) { - rec_reset_filtration(simplex.second.children(), filt_value, (max_dim - 1)); + void rec_reset_filtration(Siblings * sib, Filtration_value filt_value, int min_dim) { + for (auto sh = sib->members().begin(); sh != sib->members().end(); ++sh) { + if (min_dim <= dimension(sh)) { + sh->second.assign_filtration(filt_value); + } + if (has_children(sh)) { + rec_reset_filtration(sh->second.children(), filt_value, min_dim); } } } diff --git a/src/Simplex_tree/test/simplex_tree_unit_test.cpp b/src/Simplex_tree/test/simplex_tree_unit_test.cpp index 9780f5b0..bdd41d34 100644 --- a/src/Simplex_tree/test/simplex_tree_unit_test.cpp +++ b/src/Simplex_tree/test/simplex_tree_unit_test.cpp @@ -965,21 +965,25 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(simplex_tree_reset_filtration, typeST, list_of_tes for (auto vertex : st.simplex_vertex_range(f_simplex)) { std::clog << vertex << ","; } - std::clog << ") - filtration =" << st.filtration(f_simplex) << std::endl; + std::clog << ") - filtration = " << st.filtration(f_simplex); + std::clog << " - dimension = " << st.dimension(f_simplex) << std::endl; + // Guaranteed by construction BOOST_CHECK(st.filtration(f_simplex) >= 2.); } // dimension until 5 even if simplex tree is of dimension 3 to test the limits - for(int dimension = 0; dimension < 6; dimension ++) { + for(int dimension = 5; dimension >= 0; dimension --) { + std::clog << "### reset_filtration - dimension = " << dimension << "\n"; st.reset_filtration(0., dimension); for (auto f_simplex : st.skeleton_simplex_range(3)) { std::clog << "vertex = ("; for (auto vertex : st.simplex_vertex_range(f_simplex)) { std::clog << vertex << ","; } - std::clog << ") - filtration =" << st.filtration(f_simplex) << std::endl; - if (st.dimension(f_simplex) > dimension) - BOOST_CHECK(st.filtration(f_simplex) >= 1.); + std::clog << ") - filtration = " << st.filtration(f_simplex); + std::clog << " - dimension = " << st.dimension(f_simplex) << std::endl; + if (st.dimension(f_simplex) < dimension) + BOOST_CHECK(st.filtration(f_simplex) >= 2.); else BOOST_CHECK(st.filtration(f_simplex) == 0.); } diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index b7682693..657d55be 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -328,7 +328,7 @@ cdef class SimplexTree: return self.get_ptr().prune_above_filtration(filtration) def expansion(self, max_dim): - """Expands the Simplex_tree containing only its one skeleton + """Expands the simplex tree containing only its one skeleton until dimension max_dim. The expanded simplicial complex until dimension :math:`d` @@ -338,7 +338,7 @@ cdef class SimplexTree: The filtration value assigned to a simplex is the maximal filtration value of one of its edges. - The Simplex_tree must contain no simplex of dimension bigger than + The simplex tree must contain no simplex of dimension bigger than 1 when calling the method. :param max_dim: The maximal dimension. @@ -358,15 +358,16 @@ cdef class SimplexTree: """ return self.get_ptr().make_filtration_non_decreasing() - def reset_filtration(self, filtration, max_dim): - """This function resets filtration value until a given dimension. + def reset_filtration(self, filtration, min_dim): + """This function resets filtration value from a given dimension. + Resets all the simplex tree when `min_dim = 0`. :param filtration: New threshold value. :type filtration: float. - :param max_dim: The maximal dimension. + :param max_dim: The minimal dimension. :type max_dim: int. """ - self.get_ptr().reset_filtration(filtration, max_dim) + self.get_ptr().reset_filtration(filtration, min_dim) def extend_filtration(self): """ Extend filtration for computing extended persistence. This function only uses the @@ -376,14 +377,14 @@ cdef class SimplexTree: .. note:: Note that after calling this function, the filtration - values are actually modified within the Simplex_tree. + values are actually modified within the simplex tree. The function :func:`extended_persistence` retrieves the original values. .. note:: Note that this code creates an extra vertex internally, so you should make sure that - the Simplex_tree does not contain a vertex with the largest possible value (i.e., 4294967295). + the simplex tree does not contain a vertex with the largest possible value (i.e., 4294967295). """ self.get_ptr().compute_extended_filtration() diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 6f1d01cc..ac2b59c7 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -367,15 +367,16 @@ def test_reset_filtration(): assert st.insert([3, 4, 5], 3.) == True assert st.insert([0, 1, 6, 7], 4.) == True + # Guaranteed by construction for simplex in st.get_simplices(): - assert st.filtration(simplex[0]) >= 0. + assert st.filtration(simplex[0]) >= 2. # dimension until 5 even if simplex tree is of dimension 3 to test the limits - for dimension in range(0, 6): + for dimension in range(5, -1, -1): st.reset_filtration(0., dimension) for simplex in st.get_skeleton(3): print(simplex) - if len(simplex[0]) > (dimension + 1): - assert st.filtration(simplex[0]) >= 1. + if len(simplex[0]) < (dimension) + 1: + assert st.filtration(simplex[0]) >= 2. else: assert st.filtration(simplex[0]) == 0. -- cgit v1.2.3 From e7b7947adf13ec1dcb8c126a4373fa29baaecb63 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Tue, 29 Sep 2020 13:23:56 +0200 Subject: Added tests for wasserstein distance with tensorflow --- .../modules/GUDHI_third_party_libraries.cmake | 1 + src/python/CMakeLists.txt | 8 +++++++ src/python/doc/installation.rst | 5 +++++ src/python/test/test_wasserstein_with_tensors.py | 25 ++++++++++++++++++++++ 4 files changed, 39 insertions(+) create mode 100755 src/python/test/test_wasserstein_with_tensors.py (limited to 'src/python/test') diff --git a/src/cmake/modules/GUDHI_third_party_libraries.cmake b/src/cmake/modules/GUDHI_third_party_libraries.cmake index 1fbc4244..9dadac4f 100644 --- a/src/cmake/modules/GUDHI_third_party_libraries.cmake +++ b/src/cmake/modules/GUDHI_third_party_libraries.cmake @@ -155,6 +155,7 @@ if( PYTHONINTERP_FOUND ) find_python_module("pykeops") find_python_module("eagerpy") find_python_module_no_version("hnswlib") + find_python_module("tensorflow") endif() if(NOT GUDHI_PYTHON_PATH) diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 4f26481e..cc71503f 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -103,6 +103,9 @@ if(PYTHONINTERP_FOUND) if(EAGERPY_FOUND) add_gudhi_debug_info("EagerPy version ${EAGERPY_VERSION}") endif() + if(TENSORFLOW_FOUND) + add_gudhi_debug_info("TensorFlow version ${TENSORFLOW_VERSION}") + endif() set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DBOOST_RESULT_OF_USE_DECLTYPE', ") set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DBOOST_ALL_NO_LIB', ") @@ -501,6 +504,11 @@ if(PYTHONINTERP_FOUND) endif() add_gudhi_py_test(test_wasserstein_barycenter) endif() + if(OT_FOUND) + if(TENSORFLOW_FOUND AND EAGERPY_FOUND) + add_gudhi_py_test(test_wasserstein_with_tensors) + endif() + endif() # Representations if(SKLEARN_FOUND AND MATPLOTLIB_FOUND) diff --git a/src/python/doc/installation.rst b/src/python/doc/installation.rst index 2f161d66..66efe45a 100644 --- a/src/python/doc/installation.rst +++ b/src/python/doc/installation.rst @@ -394,6 +394,11 @@ mathematics, science, and engineering. :class:`~gudhi.point_cloud.knn.KNearestNeighbors` can use the Python package `SciPy `_ as a backend if explicitly requested. +TensorFlow +---------- + +`TensorFlow `_ is currently only used in some automatic differentiation tests. + Bug reports and contributions ***************************** diff --git a/src/python/test/test_wasserstein_with_tensors.py b/src/python/test/test_wasserstein_with_tensors.py new file mode 100755 index 00000000..8957705d --- /dev/null +++ b/src/python/test/test_wasserstein_with_tensors.py @@ -0,0 +1,25 @@ +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Mathieu Carriere + + Copyright (C) 2020 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +from gudhi.wasserstein import wasserstein_distance as pot +import numpy as np + +def test_wasserstein_distance_grad_tensorflow(): + import tensorflow as tf + + with tf.GradientTape() as tape: + diag4 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[0., 10.]]), trainable=True)) + diag5 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[1., 11.], [3., 4.]]), trainable=True)) + dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True) + assert dist45 == 3. + + grads = tape.gradient(dist45, [diag4, diag5]) + assert np.array_equal(grads[0].values, [[-1., -1.]]) + assert np.array_equal(grads[1].values, [[1., 1.], [-1., 1.]]) \ No newline at end of file -- cgit v1.2.3 From f0beb329f5a1767e4e0a0575ef3e078bf4563a44 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Mon, 5 Oct 2020 11:12:44 +0200 Subject: code review: move test_wasserstein_distance_grad from test_wasserstein_distance.py to test_wasserstein_with_tensors.py --- src/python/CMakeLists.txt | 5 +++-- src/python/test/test_wasserstein_distance.py | 24 ---------------------- src/python/test/test_wasserstein_with_tensors.py | 26 ++++++++++++++++++++++-- 3 files changed, 27 insertions(+), 28 deletions(-) (limited to 'src/python/test') diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index cc71503f..c09996fe 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -499,13 +499,14 @@ if(PYTHONINTERP_FOUND) # Wasserstein if(OT_FOUND AND PYBIND11_FOUND) - if(TORCH_FOUND AND EAGERPY_FOUND) + # EagerPy dependency because of enable_autodiff=True + if(EAGERPY_FOUND) add_gudhi_py_test(test_wasserstein_distance) endif() add_gudhi_py_test(test_wasserstein_barycenter) endif() if(OT_FOUND) - if(TENSORFLOW_FOUND AND EAGERPY_FOUND) + if(TORCH_FOUND AND TENSORFLOW_FOUND AND EAGERPY_FOUND) add_gudhi_py_test(test_wasserstein_with_tensors) endif() endif() diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 90d26809..e3b521d6 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -97,27 +97,3 @@ def test_wasserstein_distance_pot(): def test_wasserstein_distance_hera(): _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False) _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False) - -def test_wasserstein_distance_grad(): - import torch - - diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True) - diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) - diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) - assert diag1.grad is None and diag2.grad is None and diag3.grad is None - dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True) - dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True) - dist12.backward() - dist30.backward() - assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any() - diag4 = torch.tensor([[0., 10.]], requires_grad=True) - diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True) - dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True) - assert dist45 == 3. - dist45.backward() - assert np.array_equal(diag4.grad, [[-1., -1.]]) - assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]]) - diag6 = torch.tensor([[5., 10.]], requires_grad=True) - pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward() - # https://github.com/jonasrauber/eagerpy/issues/6 - # assert np.array_equal(diag6.grad, [[0., 0.]]) diff --git a/src/python/test/test_wasserstein_with_tensors.py b/src/python/test/test_wasserstein_with_tensors.py index 8957705d..e3f1411a 100755 --- a/src/python/test/test_wasserstein_with_tensors.py +++ b/src/python/test/test_wasserstein_with_tensors.py @@ -10,10 +10,32 @@ from gudhi.wasserstein import wasserstein_distance as pot import numpy as np +import torch +import tensorflow as tf -def test_wasserstein_distance_grad_tensorflow(): - import tensorflow as tf +def test_wasserstein_distance_grad(): + diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True) + diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) + diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) + assert diag1.grad is None and diag2.grad is None and diag3.grad is None + dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True) + dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True) + dist12.backward() + dist30.backward() + assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any() + diag4 = torch.tensor([[0., 10.]], requires_grad=True) + diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True) + dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True) + assert dist45 == 3. + dist45.backward() + assert np.array_equal(diag4.grad, [[-1., -1.]]) + assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]]) + diag6 = torch.tensor([[5., 10.]], requires_grad=True) + pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward() + # https://github.com/jonasrauber/eagerpy/issues/6 + # assert np.array_equal(diag6.grad, [[0., 0.]]) +def test_wasserstein_distance_grad_tensorflow(): with tf.GradientTape() as tape: diag4 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[0., 10.]]), trainable=True)) diag5 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[1., 11.], [3., 4.]]), trainable=True)) -- cgit v1.2.3 From 0022442a303f297ac773e262abd2661d2ce0a614 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Sun, 8 Nov 2020 22:08:21 +0100 Subject: Test BettiCurve with infinite value + black reformatting --- src/python/test/test_representations.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) (limited to 'src/python/test') diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py index e5c211a0..43c914f3 100755 --- a/src/python/test/test_representations.py +++ b/src/python/test/test_representations.py @@ -39,11 +39,11 @@ def test_multiple(): d2 = BottleneckDistance(epsilon=0.00001).fit_transform(l1) d3 = pairwise_persistence_diagram_distances(l1, l1b, e=0.00001, n_jobs=4) assert d1 == pytest.approx(d2) - assert d3 == pytest.approx(d2, abs=1e-5) # Because of 0 entries (on the diagonal) + assert d3 == pytest.approx(d2, abs=1e-5) # Because of 0 entries (on the diagonal) d1 = pairwise_persistence_diagram_distances(l1, l2, metric="wasserstein", order=2, internal_p=2) d2 = WassersteinDistance(order=2, internal_p=2, n_jobs=4).fit(l2).transform(l1) print(d1.shape, d2.shape) - assert d1 == pytest.approx(d2, rel=.02) + assert d1 == pytest.approx(d2, rel=0.02) def test_dummy_atol(): @@ -53,8 +53,22 @@ def test_dummy_atol(): for weighting_method in ["cloud", "iidproba"]: for contrast in ["gaussian", "laplacian", "indicator"]: - atol_vectoriser = Atol(quantiser=KMeans(n_clusters=1, random_state=202006), weighting_method=weighting_method, contrast=contrast) + atol_vectoriser = Atol( + quantiser=KMeans(n_clusters=1, random_state=202006), + weighting_method=weighting_method, + contrast=contrast, + ) atol_vectoriser.fit([a, b, c]) atol_vectoriser(a) atol_vectoriser.transform(X=[a, b, c]) + +from gudhi.representations.vector_methods import BettiCurve + + +def test_infinity(): + a = np.array([[1.0, 8.0], [2.0, np.inf], [3.0, 4.0]]) + c = BettiCurve(20, [0.0, 10.0])(a) + assert c[1] == 0 + assert c[7] == 3 + assert c[9] == 2 -- cgit v1.2.3 From 0a071114ad08d2ce149f8b484dd8ff1b96b61fb1 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Mon, 9 Nov 2020 22:55:00 +0100 Subject: Don't test the equality case in sparsify_point_set. sqrt. --- src/Subsampling/include/gudhi/sparsify_point_set.h | 6 ++++-- src/python/gudhi/subsampling.pyx | 2 +- src/python/test/test_subsampling.py | 16 ++++++++++------ 3 files changed, 15 insertions(+), 9 deletions(-) (limited to 'src/python/test') diff --git a/src/Subsampling/include/gudhi/sparsify_point_set.h b/src/Subsampling/include/gudhi/sparsify_point_set.h index 78e0da4a..afa6d45a 100644 --- a/src/Subsampling/include/gudhi/sparsify_point_set.h +++ b/src/Subsampling/include/gudhi/sparsify_point_set.h @@ -29,7 +29,7 @@ namespace subsampling { * \ingroup subsampling * \brief Outputs a subset of the input points so that the * squared distance between any two points - * is greater than or equal to `min_squared_dist`. + * is greater than `min_squared_dist`. * * \tparam Kernel must be a model of the SearchTraits @@ -53,6 +53,7 @@ sparsify_point_set( OutputIterator output_it) { typedef typename Gudhi::spatial_searching::Kd_tree_search< Kernel, Point_range> Points_ds; + using std::sqrt; #ifdef GUDHI_SUBSAMPLING_PROFILING Gudhi::Clock t; @@ -73,7 +74,8 @@ sparsify_point_set( // If another point Q is closer that min_squared_dist, mark Q to be dropped auto drop = [&dropped_points] (std::ptrdiff_t neighbor_point_idx) { dropped_points[neighbor_point_idx] = true; }; - points_ds.all_near_neighbors(pt, min_squared_dist, boost::make_function_output_iterator(std::ref(drop))); + // FIXME: what if FT does not support sqrt? + points_ds.all_near_neighbors(pt, sqrt(min_squared_dist), boost::make_function_output_iterator(std::ref(drop))); } #ifdef GUDHI_SUBSAMPLING_PROFILING diff --git a/src/python/gudhi/subsampling.pyx b/src/python/gudhi/subsampling.pyx index b11d07e5..46f32335 100644 --- a/src/python/gudhi/subsampling.pyx +++ b/src/python/gudhi/subsampling.pyx @@ -105,7 +105,7 @@ def pick_n_random_points(points=None, off_file='', nb_points=0): def sparsify_point_set(points=None, off_file='', min_squared_dist=0.0): """Outputs a subset of the input points so that the squared distance - between any two points is greater than or equal to min_squared_dist. + between any two points is greater than min_squared_dist. :param points: The input point set. :type points: Iterable[Iterable[float]] diff --git a/src/python/test/test_subsampling.py b/src/python/test/test_subsampling.py index 31f64e32..4019852e 100755 --- a/src/python/test/test_subsampling.py +++ b/src/python/test/test_subsampling.py @@ -141,12 +141,16 @@ def test_simple_sparsify_points(): # assert gudhi.sparsify_point_set(points = [], min_squared_dist = 0.0) == [] # assert gudhi.sparsify_point_set(points = [], min_squared_dist = 10.0) == [] assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=0.0) == point_set - assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=1.0) == point_set - assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=2.0) == [ + assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=0.999) == point_set + assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=1.001) == [ [0, 1], [1, 0], ] - assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=2.01) == [[0, 1]] + assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=1.999) == [ + [0, 1], + [1, 0], + ] + assert gudhi.sparsify_point_set(points=point_set, min_squared_dist=2.001) == [[0, 1]] assert ( len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=0.0)) @@ -157,11 +161,11 @@ def test_simple_sparsify_points(): == 5 ) assert ( - len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=40.0)) + len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=40.1)) == 4 ) assert ( - len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=90.0)) + len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=89.9)) == 3 ) assert ( @@ -169,7 +173,7 @@ def test_simple_sparsify_points(): == 2 ) assert ( - len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=325.0)) + len(gudhi.sparsify_point_set(off_file="subsample.off", min_squared_dist=324.9)) == 2 ) assert ( -- cgit v1.2.3 From 53376fde3f35576af18fac33d731e8398da7522e Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Fri, 13 Nov 2020 12:39:27 +0100 Subject: Test with negative coordinates --- src/python/test/test_bottleneck_distance.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) (limited to 'src/python/test') diff --git a/src/python/test/test_bottleneck_distance.py b/src/python/test/test_bottleneck_distance.py index 6915bea8..07fcc9cc 100755 --- a/src/python/test/test_bottleneck_distance.py +++ b/src/python/test/test_bottleneck_distance.py @@ -25,3 +25,15 @@ def test_basic_bottleneck(): assert gudhi.bottleneck_distance(diag1, diag2, 0.1) == pytest.approx(0.75, abs=0.1) assert gudhi.hera.bottleneck_distance(diag1, diag2, 0) == 0.75 assert gudhi.hera.bottleneck_distance(diag1, diag2, 0.1) == pytest.approx(0.75, rel=0.1) + + import numpy as np + + # Translating both diagrams along the diagonal should not affect the distance, + # checks that negative numbers are not an issue + diag1 = np.array(diag1) - 100 + diag2 = np.array(diag2) - 100 + + assert gudhi.bottleneck_distance(diag1, diag2) == 0.75 + assert gudhi.bottleneck_distance(diag1, diag2, 0.1) == pytest.approx(0.75, abs=0.1) + assert gudhi.hera.bottleneck_distance(diag1, diag2, 0) == 0.75 + assert gudhi.hera.bottleneck_distance(diag1, diag2, 0.1) == pytest.approx(0.75, rel=0.1) -- cgit v1.2.3 From 55e2b2e55bc50a7cfea9ca1edfca632488cf016a Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Tue, 8 Dec 2020 15:49:43 +0100 Subject: Make representations tests work if CGAL and/or POT is not there --- .../diagram_vectorizations_distances_kernels.py | 19 +++++++++++++------ src/python/test/test_representations.py | 10 +++++++++- 2 files changed, 22 insertions(+), 7 deletions(-) (limited to 'src/python/test') diff --git a/src/python/example/diagram_vectorizations_distances_kernels.py b/src/python/example/diagram_vectorizations_distances_kernels.py index c4a71a7a..2801576e 100755 --- a/src/python/example/diagram_vectorizations_distances_kernels.py +++ b/src/python/example/diagram_vectorizations_distances_kernels.py @@ -5,11 +5,11 @@ import numpy as np from sklearn.kernel_approximation import RBFSampler from sklearn.preprocessing import MinMaxScaler -from gudhi.representations import DiagramSelector, Clamping, Landscape, Silhouette, BettiCurve, ComplexPolynomial,\ +from gudhi.representations import (DiagramSelector, Clamping, Landscape, Silhouette, BettiCurve, ComplexPolynomial,\ TopologicalVector, DiagramScaler, BirthPersistenceTransform,\ PersistenceImage, PersistenceWeightedGaussianKernel, Entropy, \ PersistenceScaleSpaceKernel, SlicedWassersteinDistance,\ - SlicedWassersteinKernel, BottleneckDistance, PersistenceFisherKernel, WassersteinDistance + SlicedWassersteinKernel, PersistenceFisherKernel, WassersteinDistance) D1 = np.array([[0.,4.],[1.,2.],[3.,8.],[6.,8.], [0., np.inf], [5., np.inf]]) @@ -93,14 +93,21 @@ print("SW distance is " + str(sW(D1, D2))) SW = SlicedWassersteinKernel(num_directions=100, bandwidth=1.) print("SW kernel is " + str(SW(D1, D2))) -W = WassersteinDistance(order=2, internal_p=2, mode="pot") -print("Wasserstein distance (POT) is " + str(W(D1, D2))) +try: + W = WassersteinDistance(order=2, internal_p=2, mode="pot") + print("Wasserstein distance (POT) is " + str(W(D1, D2))) +except ImportError: + print("WassersteinDistance (POT) is not available, you may be missing pot.") W = WassersteinDistance(order=2, internal_p=2, mode="hera", delta=0.0001) print("Wasserstein distance (hera) is " + str(W(D1, D2))) -W = BottleneckDistance(epsilon=.001) -print("Bottleneck distance is " + str(W(D1, D2))) +try: + from gudhi.representations import BottleneckDistance + W = BottleneckDistance(epsilon=.001) + print("Bottleneck distance is " + str(W(D1, D2))) +except ImportError: + print("BottleneckDistance is not available, you may be missing CGAL.") PF = PersistenceFisherKernel(bandwidth_fisher=1., bandwidth=1.) print("PF kernel is " + str(PF(D1, D2))) diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py index 43c914f3..8ebd7888 100755 --- a/src/python/test/test_representations.py +++ b/src/python/test/test_representations.py @@ -41,7 +41,15 @@ def test_multiple(): assert d1 == pytest.approx(d2) assert d3 == pytest.approx(d2, abs=1e-5) # Because of 0 entries (on the diagonal) d1 = pairwise_persistence_diagram_distances(l1, l2, metric="wasserstein", order=2, internal_p=2) - d2 = WassersteinDistance(order=2, internal_p=2, n_jobs=4).fit(l2).transform(l1) + mode = "" + try: + import ot + mode = "pot" + except ImportError: + print("POT is not available, try with hera") + mode = "hera" + + d2 = WassersteinDistance(order=2, internal_p=2, mode=mode, n_jobs=4).fit(l2).transform(l1) print(d1.shape, d2.shape) assert d1 == pytest.approx(d2, rel=0.02) -- cgit v1.2.3 From 366813c651d871c5b95d9bc8e7ea227f8015fc55 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Tue, 8 Dec 2020 16:35:48 +0100 Subject: rollback test_representation and a CGAL condition to launch it --- src/python/CMakeLists.txt | 2 +- src/python/test/test_representations.py | 10 +--------- 2 files changed, 2 insertions(+), 10 deletions(-) (limited to 'src/python/test') diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 56b6876c..4dca3908 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -504,7 +504,7 @@ if(PYTHONINTERP_FOUND) endif() # Representations - if(SKLEARN_FOUND AND MATPLOTLIB_FOUND) + if(SKLEARN_FOUND AND MATPLOTLIB_FOUND AND NOT CGAL_VERSION VERSION_LESS 4.11.0) add_gudhi_py_test(test_representations) endif() diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py index 8ebd7888..43c914f3 100755 --- a/src/python/test/test_representations.py +++ b/src/python/test/test_representations.py @@ -41,15 +41,7 @@ def test_multiple(): assert d1 == pytest.approx(d2) assert d3 == pytest.approx(d2, abs=1e-5) # Because of 0 entries (on the diagonal) d1 = pairwise_persistence_diagram_distances(l1, l2, metric="wasserstein", order=2, internal_p=2) - mode = "" - try: - import ot - mode = "pot" - except ImportError: - print("POT is not available, try with hera") - mode = "hera" - - d2 = WassersteinDistance(order=2, internal_p=2, mode=mode, n_jobs=4).fit(l2).transform(l1) + d2 = WassersteinDistance(order=2, internal_p=2, n_jobs=4).fit(l2).transform(l1) print(d1.shape, d2.shape) assert d1 == pytest.approx(d2, rel=0.02) -- cgit v1.2.3 From fda0084941ece5d41a258d19ca4eb0b3d87384a4 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Wed, 9 Dec 2020 09:41:13 +0100 Subject: Fix #388 --- src/cmake/modules/GUDHI_third_party_libraries.cmake | 1 - src/python/CMakeLists.txt | 1 + src/python/gudhi/simplex_tree.pxd | 2 ++ src/python/gudhi/simplex_tree.pyx | 3 +++ src/python/include/Simplex_tree_interface.h | 10 +++++++++- src/python/test/test_simplex_tree.py | 14 +++++++++----- 6 files changed, 24 insertions(+), 7 deletions(-) (limited to 'src/python/test') diff --git a/src/cmake/modules/GUDHI_third_party_libraries.cmake b/src/cmake/modules/GUDHI_third_party_libraries.cmake index e2684aa4..e1566877 100644 --- a/src/cmake/modules/GUDHI_third_party_libraries.cmake +++ b/src/cmake/modules/GUDHI_third_party_libraries.cmake @@ -58,7 +58,6 @@ endif(WITH_GUDHI_USE_TBB) set(CGAL_WITH_EIGEN3_VERSION 0.0.0) find_package(Eigen3 3.1.0) if (EIGEN3_FOUND) - add_definitions(-DGUDHI_USE_EIGEN3) include( ${EIGEN3_USE_FILE} ) set(CGAL_WITH_EIGEN3_VERSION ${CGAL_VERSION}) endif (EIGEN3_FOUND) diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 2d5f793a..45c89609 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -133,6 +133,7 @@ if(PYTHONINTERP_FOUND) add_gudhi_debug_info("Eigen3 version ${EIGEN3_VERSION}") # No problem, even if no CGAL found set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DCGAL_EIGEN3_ENABLED', ") + set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DGUDHI_USE_EIGEN3', ") endif (EIGEN3_FOUND) set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'off_reader', ") diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd index 3c4cbed3..283830ff 100644 --- a/src/python/gudhi/simplex_tree.pxd +++ b/src/python/gudhi/simplex_tree.pxd @@ -74,6 +74,8 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": Simplex_tree_skeleton_iterator get_skeleton_iterator_begin(int dimension) nogil Simplex_tree_skeleton_iterator get_skeleton_iterator_end(int dimension) nogil pair[Simplex_tree_boundary_iterator, Simplex_tree_boundary_iterator] get_boundary_iterators(vector[int] simplex) nogil except + + + cdef int _GUDHI_USE_EIGEN3 cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi": cdef cppclass Simplex_tree_persistence_interface "Gudhi::Persistent_cohomology_interface>": diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index 7d6ab89a..eb44c0fc 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -17,6 +17,9 @@ __author__ = "Vincent Rouvreau" __copyright__ = "Copyright (C) 2016 Inria" __license__ = "MIT" +# For unitary tests purpose +__GUDHI_USE_EIGEN3 = _GUDHI_USE_EIGEN3 + # SimplexTree python interface cdef class SimplexTree: """The simplex tree is an efficient and flexible data structure for diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h index 2bd704b4..50592e25 100644 --- a/src/python/include/Simplex_tree_interface.h +++ b/src/python/include/Simplex_tree_interface.h @@ -27,6 +27,12 @@ namespace Gudhi { +#ifdef GUDHI_USE_EIGEN3 +const int _GUDHI_USE_EIGEN3 = 1; +#else +const int _GUDHI_USE_EIGEN3 = 0; +#endif + template class Simplex_tree_interface : public Simplex_tree { public: @@ -191,7 +197,9 @@ class Simplex_tree_interface : public Simplex_tree { } return collapsed_stree_ptr; #else - return this; + // If no Eigen3, return a copy, as it will be deleted in pyx + Simplex_tree_interface* collapsed_stree_ptr = new Simplex_tree_interface(*this); + return collapsed_stree_ptr; #endif } diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 3b23fa0b..15b472ee 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -8,7 +8,7 @@ - YYYY/MM Author: Description of the modification """ -from gudhi import SimplexTree +from gudhi import SimplexTree, simplex_tree import pytest __author__ = "Vincent Rouvreau" @@ -353,11 +353,15 @@ def test_collapse_edges(): assert st.num_simplices() == 10 + # If no Eigen3, collapse_edges just return a copy, no action. Maybe it would require some user warning st.collapse_edges() - assert st.num_simplices() == 9 - assert st.find([1, 3]) == False - for simplex in st.get_skeleton(0): - assert simplex[1] == 1. + if simplex_tree.__GUDHI_USE_EIGEN3: + assert st.num_simplices() == 9 + assert st.find([1, 3]) == False + for simplex in st.get_skeleton(0): + assert simplex[1] == 1. + else: + assert st.num_simplices() == 10 def test_reset_filtration(): st = SimplexTree() -- cgit v1.2.3 From 957da77f9484972ce34d0415502887f92080878e Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Fri, 11 Dec 2020 09:38:12 +0100 Subject: code review: GUDHI_USE_EIGEN3 generated by CMake in __init__.py as suggested and roll back the other solution --- src/python/CMakeLists.txt | 2 ++ src/python/gudhi/__init__.py.in | 4 ++++ src/python/gudhi/simplex_tree.pxd | 2 -- src/python/gudhi/simplex_tree.pyx | 3 --- src/python/include/Simplex_tree_interface.h | 6 ------ src/python/test/test_simplex_tree.py | 4 ++-- 6 files changed, 8 insertions(+), 13 deletions(-) (limited to 'src/python/test') diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 45c89609..5a245aac 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -129,11 +129,13 @@ if(PYTHONINTERP_FOUND) set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DDEBUG_TRACES', ") endif() + set(GUDHI_USE_EIGEN3 "False") if (EIGEN3_FOUND) add_gudhi_debug_info("Eigen3 version ${EIGEN3_VERSION}") # No problem, even if no CGAL found set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DCGAL_EIGEN3_ENABLED', ") set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DGUDHI_USE_EIGEN3', ") + set(GUDHI_USE_EIGEN3 "True") endif (EIGEN3_FOUND) set(GUDHI_CYTHON_MODULES "${GUDHI_CYTHON_MODULES}'off_reader', ") diff --git a/src/python/gudhi/__init__.py.in b/src/python/gudhi/__init__.py.in index 79e12fbc..3043201a 100644 --- a/src/python/gudhi/__init__.py.in +++ b/src/python/gudhi/__init__.py.in @@ -23,6 +23,10 @@ __all__ = [@GUDHI_PYTHON_MODULES@ @GUDHI_PYTHON_MODULES_EXTRA@] __available_modules = '' __missing_modules = '' +# For unitary tests purpose +# could use "if 'collapse_edges' in gudhi.__all__" when collapse edges will have a python module +__GUDHI_USE_EIGEN3 = @GUDHI_USE_EIGEN3@ + # Try to import * from gudhi.__module_name for default modules. # Extra modules require an explicit import by the user (mostly because of # unusual dependencies, but also to avoid cluttering namespace gudhi and diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd index 283830ff..3c4cbed3 100644 --- a/src/python/gudhi/simplex_tree.pxd +++ b/src/python/gudhi/simplex_tree.pxd @@ -74,8 +74,6 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": Simplex_tree_skeleton_iterator get_skeleton_iterator_begin(int dimension) nogil Simplex_tree_skeleton_iterator get_skeleton_iterator_end(int dimension) nogil pair[Simplex_tree_boundary_iterator, Simplex_tree_boundary_iterator] get_boundary_iterators(vector[int] simplex) nogil except + - - cdef int _GUDHI_USE_EIGEN3 cdef extern from "Persistent_cohomology_interface.h" namespace "Gudhi": cdef cppclass Simplex_tree_persistence_interface "Gudhi::Persistent_cohomology_interface>": diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index eb44c0fc..7d6ab89a 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -17,9 +17,6 @@ __author__ = "Vincent Rouvreau" __copyright__ = "Copyright (C) 2016 Inria" __license__ = "MIT" -# For unitary tests purpose -__GUDHI_USE_EIGEN3 = _GUDHI_USE_EIGEN3 - # SimplexTree python interface cdef class SimplexTree: """The simplex tree is an efficient and flexible data structure for diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h index 50592e25..82444609 100644 --- a/src/python/include/Simplex_tree_interface.h +++ b/src/python/include/Simplex_tree_interface.h @@ -27,12 +27,6 @@ namespace Gudhi { -#ifdef GUDHI_USE_EIGEN3 -const int _GUDHI_USE_EIGEN3 = 1; -#else -const int _GUDHI_USE_EIGEN3 = 0; -#endif - template class Simplex_tree_interface : public Simplex_tree { public: diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 15b472ee..0c072baa 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -8,7 +8,7 @@ - YYYY/MM Author: Description of the modification """ -from gudhi import SimplexTree, simplex_tree +from gudhi import SimplexTree, __GUDHI_USE_EIGEN3 import pytest __author__ = "Vincent Rouvreau" @@ -355,7 +355,7 @@ def test_collapse_edges(): # If no Eigen3, collapse_edges just return a copy, no action. Maybe it would require some user warning st.collapse_edges() - if simplex_tree.__GUDHI_USE_EIGEN3: + if __GUDHI_USE_EIGEN3: assert st.num_simplices() == 9 assert st.find([1, 3]) == False for simplex in st.get_skeleton(0): -- cgit v1.2.3 From 40e0976e8dae27f0e30f3c9df1fd7de1a7343948 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Fri, 11 Dec 2020 10:31:09 +0100 Subject: code review: throw an eception if collapse_edges when no Eigen3 --- src/python/gudhi/simplex_tree.pxd | 2 +- src/python/gudhi/simplex_tree.pyx | 4 ++-- src/python/include/Simplex_tree_interface.h | 4 +--- src/python/test/test_simplex_tree.py | 7 ++++--- 4 files changed, 8 insertions(+), 9 deletions(-) (limited to 'src/python/test') diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd index 3c4cbed3..000323af 100644 --- a/src/python/gudhi/simplex_tree.pxd +++ b/src/python/gudhi/simplex_tree.pxd @@ -63,7 +63,7 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": bool make_filtration_non_decreasing() nogil void compute_extended_filtration() nogil vector[vector[pair[int, pair[double, double]]]] compute_extended_persistence_subdiagrams(vector[pair[int, pair[double, double]]] dgm, double min_persistence) nogil - Simplex_tree_interface_full_featured* collapse_edges(int nb_collapse_iteration) nogil + Simplex_tree_interface_full_featured* collapse_edges(int nb_collapse_iteration) nogil except + void reset_filtration(double filtration, int dimension) nogil # Iterators over Simplex tree pair[vector[int], double] get_simplex_and_filtration(Simplex_tree_simplex_handle f_simplex) nogil diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index 7d6ab89a..665d41e6 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -628,8 +628,8 @@ cdef class SimplexTree: :param nb_iterations: The number of edge collapse iterations to perform. Default is 1. :type nb_iterations: int - :note: collapse_edges function requires `Eigen `_ >= 3.1.0, otherwise no action is - performed. + :note: collapse_edges method requires `Eigen `_ >= 3.1.0 and an exception is thrown + if this method is not available. """ # Backup old pointer cdef Simplex_tree_interface_full_featured* ptr = self.get_ptr() diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h index 82444609..629f6083 100644 --- a/src/python/include/Simplex_tree_interface.h +++ b/src/python/include/Simplex_tree_interface.h @@ -191,9 +191,7 @@ class Simplex_tree_interface : public Simplex_tree { } return collapsed_stree_ptr; #else - // If no Eigen3, return a copy, as it will be deleted in pyx - Simplex_tree_interface* collapsed_stree_ptr = new Simplex_tree_interface(*this); - return collapsed_stree_ptr; + throw std::runtime_error("Unable to collapse edges as it requires Eigen3 >= 3.1.0."); #endif } diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 0c072baa..a3eacaa9 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -353,15 +353,16 @@ def test_collapse_edges(): assert st.num_simplices() == 10 - # If no Eigen3, collapse_edges just return a copy, no action. Maybe it would require some user warning - st.collapse_edges() if __GUDHI_USE_EIGEN3: + st.collapse_edges() assert st.num_simplices() == 9 assert st.find([1, 3]) == False for simplex in st.get_skeleton(0): assert simplex[1] == 1. else: - assert st.num_simplices() == 10 + # If no Eigen3, collapse_edges throws an exception + with pytest.raises(RuntimeError): + st.collapse_edges() def test_reset_filtration(): st = SimplexTree() -- cgit v1.2.3 From 0afc650917ddf9fc4cf95fd86e0b6408f64a465d Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Mon, 11 Jan 2021 11:29:20 +0100 Subject: Remove sphinx doc test for atol as points order can be inverted and add it in a UT but sorted --- src/python/gudhi/representations/vector_methods.py | 14 +++++++------- src/python/test/test_representations.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 7 deletions(-) (limited to 'src/python/test') diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py index cdcb1fde..d4449e7d 100644 --- a/src/python/gudhi/representations/vector_methods.py +++ b/src/python/gudhi/representations/vector_methods.py @@ -606,16 +606,16 @@ class Atol(BaseEstimator, TransformerMixin): >>> c = np.array([[3, 2, -1], [1, 2, -1]]) >>> atol_vectoriser = Atol(quantiser=KMeans(n_clusters=2, random_state=202006)) >>> atol_vectoriser.fit(X=[a, b, c]).centers - array([[ 2. , 0.66666667, 3.33333333], - [ 2.6 , 2.8 , -0.4 ]]) + >>> # array([[ 2. , 0.66666667, 3.33333333], + >>> # [ 2.6 , 2.8 , -0.4 ]]) >>> atol_vectoriser(a) - array([1.18168665, 0.42375966]) + >>> # array([1.18168665, 0.42375966]) >>> atol_vectoriser(c) - array([0.02062512, 1.25157463]) + >>> # array([0.02062512, 1.25157463]) >>> atol_vectoriser.transform(X=[a, b, c]) - array([[1.18168665, 0.42375966], - [0.29861028, 1.06330156], - [0.02062512, 1.25157463]]) + >>> # array([[1.18168665, 0.42375966], + >>> # [0.29861028, 1.06330156], + >>> # [0.02062512, 1.25157463]]) """ def __init__(self, quantiser, weighting_method="cloud", contrast="gaussian"): """ diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py index 43c914f3..1c8f8cdb 100755 --- a/src/python/test/test_representations.py +++ b/src/python/test/test_representations.py @@ -46,6 +46,24 @@ def test_multiple(): assert d1 == pytest.approx(d2, rel=0.02) +# Test sorted values as points order can be inverted, and sorted test is not documentation-friendly +def test_atol_doc(): + a = np.array([[1, 2, 4], [1, 4, 0], [1, 0, 4]]) + b = np.array([[4, 2, 0], [4, 4, 0], [4, 0, 2]]) + c = np.array([[3, 2, -1], [1, 2, -1]]) + + atol_vectoriser = Atol(quantiser=KMeans(n_clusters=2, random_state=202006)) + assert np.sort(atol_vectoriser.fit(X=[a, b, c]).centers, axis=0) == \ + pytest.approx(np.array([[2. , 0.66666667, -0.4], \ + [2.6, 2.8 , 3.33333333]])) + assert np.sort(atol_vectoriser(a)) == pytest.approx(np.array([0.42375966, 1.18168665])) + assert np.sort(atol_vectoriser(c)) == pytest.approx(np.array([0.02062512, 1.25157463])) + assert np.sort(atol_vectoriser.transform(X=[a, b, c]), axis=0) == \ + pytest.approx(np.array([[0.02062512, 0.42375966], \ + [0.29861028, 1.06330156], \ + [1.18168665, 1.25157463]])) + + def test_dummy_atol(): a = np.array([[1, 2, 4], [1, 4, 0], [1, 0, 4]]) b = np.array([[4, 2, 0], [4, 4, 0], [4, 0, 2]]) -- cgit v1.2.3 From 60907b0104a2807667f175d9a8a328fd3f7f4ec8 Mon Sep 17 00:00:00 2001 From: ROUVREAU Vincent Date: Mon, 11 Jan 2021 16:25:18 +0100 Subject: Ignore doctest for atol doc. Rewrite unitary test for atol doc. To be synchronized --- src/python/gudhi/representations/vector_methods.py | 9 ++++---- src/python/test/test_representations.py | 26 ++++++++++++++-------- 2 files changed, 22 insertions(+), 13 deletions(-) (limited to 'src/python/test') diff --git a/src/python/gudhi/representations/vector_methods.py b/src/python/gudhi/representations/vector_methods.py index 5ec2abd0..84bc99a2 100644 --- a/src/python/gudhi/representations/vector_methods.py +++ b/src/python/gudhi/representations/vector_methods.py @@ -605,18 +605,19 @@ class Atol(BaseEstimator, TransformerMixin): >>> b = np.array([[4, 2, 0], [4, 4, 0], [4, 0, 2]]) >>> c = np.array([[3, 2, -1], [1, 2, -1]]) >>> atol_vectoriser = Atol(quantiser=KMeans(n_clusters=2, random_state=202006)) - >>> atol_vectoriser.fit(X=[a, b, c]).centers #doctest: +SKIP + >>> atol_vectoriser.fit(X=[a, b, c]).centers # doctest: +SKIP >>> # array([[ 2. , 0.66666667, 3.33333333], >>> # [ 2.6 , 2.8 , -0.4 ]]) >>> atol_vectoriser(a) - >>> # array([1.18168665, 0.42375966]) #doctest: +SKIP + >>> # array([1.18168665, 0.42375966]) # doctest: +SKIP >>> atol_vectoriser(c) - >>> # array([0.02062512, 1.25157463]) #doctest: +SKIP - >>> atol_vectoriser.transform(X=[a, b, c]) #doctest: +SKIP + >>> # array([0.02062512, 1.25157463]) # doctest: +SKIP + >>> atol_vectoriser.transform(X=[a, b, c]) # doctest: +SKIP >>> # array([[1.18168665, 0.42375966], >>> # [0.29861028, 1.06330156], >>> # [0.02062512, 1.25157463]]) """ + # Note the example above must be up to date with the one in tests called test_atol_doc def __init__(self, quantiser, weighting_method="cloud", contrast="gaussian"): """ Constructor for the Atol measure vectorisation class. diff --git a/src/python/test/test_representations.py b/src/python/test/test_representations.py index 1c8f8cdb..cda1a15b 100755 --- a/src/python/test/test_representations.py +++ b/src/python/test/test_representations.py @@ -47,21 +47,29 @@ def test_multiple(): # Test sorted values as points order can be inverted, and sorted test is not documentation-friendly +# Note the test below must be up to date with the Atol class documentation def test_atol_doc(): a = np.array([[1, 2, 4], [1, 4, 0], [1, 0, 4]]) b = np.array([[4, 2, 0], [4, 4, 0], [4, 0, 2]]) c = np.array([[3, 2, -1], [1, 2, -1]]) atol_vectoriser = Atol(quantiser=KMeans(n_clusters=2, random_state=202006)) - assert np.sort(atol_vectoriser.fit(X=[a, b, c]).centers, axis=0) == \ - pytest.approx(np.array([[2. , 0.66666667, -0.4], \ - [2.6, 2.8 , 3.33333333]])) - assert np.sort(atol_vectoriser(a)) == pytest.approx(np.array([0.42375966, 1.18168665])) - assert np.sort(atol_vectoriser(c)) == pytest.approx(np.array([0.02062512, 1.25157463])) - assert np.sort(atol_vectoriser.transform(X=[a, b, c]), axis=0) == \ - pytest.approx(np.array([[0.02062512, 0.42375966], \ - [0.29861028, 1.06330156], \ - [1.18168665, 1.25157463]])) + # Atol will do + # X = np.concatenate([a,b,c]) + # kmeans = KMeans(n_clusters=2, random_state=202006).fit(X) + # kmeans.labels_ will be : array([1, 0, 1, 0, 0, 1, 0, 0]) + first_cluster = np.asarray([a[0], a[2], b[2]]) + second_cluster = np.asarray([a[1], b[0], b[2], c[0], c[1]]) + + # Check the center of the first_cluster and second_cluster are in Atol centers + centers = atol_vectoriser.fit(X=[a, b, c]).centers + np.isclose(centers, first_cluster.mean(axis=0)).all(1).any() + np.isclose(centers, second_cluster.mean(axis=0)).all(1).any() + + vectorization = atol_vectoriser.transform(X=[a, b, c]) + assert np.allclose(vectorization[0], atol_vectoriser(a)) + assert np.allclose(vectorization[1], atol_vectoriser(b)) + assert np.allclose(vectorization[2], atol_vectoriser(c)) def test_dummy_atol(): -- cgit v1.2.3