diff options
-rw-r--r-- | .github/test-requirements.txt | 3 | ||||
-rw-r--r-- | src/Simplex_tree/include/gudhi/Simplex_tree.h | 75 | ||||
-rw-r--r-- | src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h | 12 | ||||
-rw-r--r-- | src/Simplex_tree/test/simplex_tree_unit_test.cpp | 36 | ||||
-rw-r--r-- | src/python/doc/wasserstein_distance_user.rst | 43 | ||||
-rw-r--r-- | src/python/gudhi/wasserstein.py | 57 | ||||
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 33 |
7 files changed, 218 insertions, 41 deletions
diff --git a/.github/test-requirements.txt b/.github/test-requirements.txt index bd03f98e..18882792 100644 --- a/.github/test-requirements.txt +++ b/.github/test-requirements.txt @@ -5,4 +5,5 @@ sphinx-paramlinks matplotlib scipy scikit-learn -POT
\ No newline at end of file +POT +tensorflow
\ No newline at end of file diff --git a/src/Simplex_tree/include/gudhi/Simplex_tree.h b/src/Simplex_tree/include/gudhi/Simplex_tree.h index f661f687..1c06e7cb 100644 --- a/src/Simplex_tree/include/gudhi/Simplex_tree.h +++ b/src/Simplex_tree/include/gudhi/Simplex_tree.h @@ -24,6 +24,7 @@ #include <boost/iterator/transform_iterator.hpp> #include <boost/graph/adjacency_list.hpp> #include <boost/range/adaptor/reversed.hpp> +#include <boost/container/static_vector.hpp> #ifdef GUDHI_USE_TBB #include <tbb/parallel_sort.h> @@ -248,8 +249,8 @@ class Simplex_tree { * which is consequenlty * equal to \f$(-1)^{\text{dim} \sigma}\f$ the canonical orientation on the simplex. */ - Simplex_vertex_range simplex_vertex_range(Simplex_handle sh) { - assert(sh != null_simplex()); // Empty simplex + Simplex_vertex_range simplex_vertex_range(Simplex_handle sh) const { + GUDHI_CHECK(sh != null_simplex(), "empty simplex"); return Simplex_vertex_range(Simplex_vertex_iterator(this, sh), Simplex_vertex_iterator(this)); } @@ -452,6 +453,15 @@ class Simplex_tree { return true; } + /** \brief Returns the filtration value of a simplex. + * + * Same as `filtration()`, but does not handle `null_simplex()`. + */ + static Filtration_value filtration_(Simplex_handle sh) { + GUDHI_CHECK (sh != null_simplex(), "null simplex"); + return sh->second.filtration(); + } + public: /** \brief Returns the key associated to a simplex. * @@ -829,7 +839,7 @@ class Simplex_tree { /** Returns the Siblings containing a simplex.*/ template<class SimplexHandle> - Siblings* self_siblings(SimplexHandle sh) { + static Siblings* self_siblings(SimplexHandle sh) { if (sh->second.children()->parent() == sh->first) return sh->second.children()->oncles(); else @@ -1589,6 +1599,65 @@ class Simplex_tree { this->make_filtration_non_decreasing(); } + /** \brief Returns a vertex of `sh` that has the same filtration value as `sh` if it exists, and `null_vertex()` otherwise. + * + * For a lower-star filtration built with `make_filtration_non_decreasing()`, this is a way to invert the process and find out which vertex had its filtration value propagated to `sh`. + * If several vertices have the same filtration value, the one it returns is arbitrary. */ + Vertex_handle vertex_with_same_filtration(Simplex_handle sh) { + auto filt = filtration_(sh); + for(auto v : simplex_vertex_range(sh)) + if(filtration_(find_vertex(v)) == filt) + return v; + return null_vertex(); + } + + /** \brief Returns an edge of `sh` that has the same filtration value as `sh` if it exists, and `null_simplex()` otherwise. + * + * For a flag-complex built with `expansion()`, this is a way to invert the process and find out which edge had its filtration value propagated to `sh`. + * If several edges have the same filtration value, the one it returns is arbitrary. + * + * \pre `sh` must have dimension at least 1. */ + Simplex_handle edge_with_same_filtration(Simplex_handle sh) { + // See issue #251 for potential speed improvements. + auto&& vertices = simplex_vertex_range(sh); // vertices in decreasing order + auto end = std::end(vertices); + auto vi = std::begin(vertices); + GUDHI_CHECK(vi != end, "empty simplex"); + auto v0 = *vi; + ++vi; + GUDHI_CHECK(vi != end, "simplex of dimension 0"); + if(std::next(vi) == end) return sh; // shortcut for dimension 1 + boost::container::static_vector<Vertex_handle, 40> suffix; + suffix.push_back(v0); + auto filt = filtration_(sh); + do + { + Vertex_handle v = *vi; + auto&& children1 = find_vertex(v)->second.children()->members_; + for(auto w : suffix){ + // Can we take advantage of the fact that suffix is ordered? + Simplex_handle s = children1.find(w); + if(filtration_(s) == filt) + return s; + } + suffix.push_back(v); + } + while(++vi != end); + return null_simplex(); + } + + /** \brief Returns a minimal face of `sh` that has the same filtration value as `sh`. + * + * For a filtration built with `make_filtration_non_decreasing()`, this is a way to invert the process and find out which simplex had its filtration value propagated to `sh`. + * If several minimal (for inclusion) simplices have the same filtration value, the one it returns is arbitrary, and it is not guaranteed to be the one with smallest dimension. */ + Simplex_handle minimal_simplex_with_same_filtration(Simplex_handle sh) { + auto filt = filtration_(sh); + // Naive implementation, it can be sped up. + for(auto b : boundary_simplex_range(sh)) + if(filtration_(b) == filt) + return minimal_simplex_with_same_filtration(b); + return sh; // None of its faces has the same filtration. + } private: Vertex_handle null_vertex_; diff --git a/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h b/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h index efccf2f2..9007b6bd 100644 --- a/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h +++ b/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h @@ -15,9 +15,7 @@ #include <boost/iterator/iterator_facade.hpp> #include <boost/version.hpp> -#if BOOST_VERSION >= 105600 -# include <boost/container/static_vector.hpp> -#endif +#include <boost/container/static_vector.hpp> #include <vector> @@ -42,13 +40,13 @@ class Simplex_tree_simplex_vertex_iterator : public boost::iterator_facade< typedef typename SimplexTree::Siblings Siblings; typedef typename SimplexTree::Vertex_handle Vertex_handle; - explicit Simplex_tree_simplex_vertex_iterator(SimplexTree * st) + explicit Simplex_tree_simplex_vertex_iterator(SimplexTree const* st) : // any end() iterator sib_(nullptr), v_(st->null_vertex()) { } - Simplex_tree_simplex_vertex_iterator(SimplexTree * st, Simplex_handle sh) + Simplex_tree_simplex_vertex_iterator(SimplexTree const* st, Simplex_handle sh) : sib_(st->self_siblings(sh)), v_(sh->first) { } @@ -166,15 +164,11 @@ class Simplex_tree_boundary_simplex_iterator : public boost::iterator_facade< // Most of the storage should be moved to the range, iterators should be light. Vertex_handle last_; // last vertex of the simplex Vertex_handle next_; // next vertex to push in suffix_ -#if BOOST_VERSION >= 105600 // 40 seems a conservative bound on the dimension of a Simplex_tree for now, // as it would not fit on the biggest hard-drive. boost::container::static_vector<Vertex_handle, 40> suffix_; // static_vector still has some overhead compared to a trivial hand-made // version using std::aligned_storage, or compared to making suffix_ static. -#else - std::vector<Vertex_handle> suffix_; -#endif Siblings * sib_; // where the next search will start from Simplex_handle sh_; // current Simplex_handle in the boundary SimplexTree * st_; // simplex containing the simplicial complex diff --git a/src/Simplex_tree/test/simplex_tree_unit_test.cpp b/src/Simplex_tree/test/simplex_tree_unit_test.cpp index e739ad0a..3872c4ca 100644 --- a/src/Simplex_tree/test/simplex_tree_unit_test.cpp +++ b/src/Simplex_tree/test/simplex_tree_unit_test.cpp @@ -902,5 +902,41 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(insert_duplicated_vertices, typeST, list_of_tested << " - num_simplices = " << st.num_simplices() << std::endl; BOOST_CHECK(st.dimension() == 1); BOOST_CHECK(st.num_simplices() == st.num_vertices() + 1); +} +BOOST_AUTO_TEST_CASE_TEMPLATE(generators, typeST, list_of_tested_variants) { + std::cout << "********************************************************************" << std::endl; + std::cout << "TEST FIND GENERATORS" << std::endl; + { + typeST st; + st.insert_simplex_and_subfaces({0,1,2,3,4,5,6},0); + st.assign_filtration(st.find({0,2,4}), 10); + st.assign_filtration(st.find({1,5}), 20); + st.assign_filtration(st.find({1,2,4}), 30); + st.assign_filtration(st.find({3}), 5); + st.make_filtration_non_decreasing(); + BOOST_CHECK(st.filtration(st.find({1,2}))==0); + BOOST_CHECK(st.filtration(st.find({0,1,2,3,4}))==30); + BOOST_CHECK(st.minimal_simplex_with_same_filtration(st.find({0,1,2,3,4,5}))==st.find({1,2,4})); + BOOST_CHECK(st.minimal_simplex_with_same_filtration(st.find({0,2,3}))==st.find({3})); + auto s=st.minimal_simplex_with_same_filtration(st.find({0,2,6})); + BOOST_CHECK(s==st.find({0})||s==st.find({2})||s==st.find({6})); + BOOST_CHECK(st.vertex_with_same_filtration(st.find({2}))==2); + BOOST_CHECK(st.vertex_with_same_filtration(st.find({1,5}))==st.null_vertex()); + BOOST_CHECK(st.vertex_with_same_filtration(st.find({5,6}))>=5); + } + { + typeST st; + st.insert_simplex_and_subfaces({0,1}, 8); + st.insert_simplex_and_subfaces({0,2}, 10); + st.insert_simplex_and_subfaces({3,4}, 6); + st.insert_simplex_and_subfaces({1,2}, 5); + st.insert_simplex_and_subfaces({1,5}, 4); + st.insert_simplex_and_subfaces({0,5}, 3); + st.insert_simplex_and_subfaces({2,5}, 2); + st.insert_simplex_and_subfaces({1,3}, 9); + st.expansion(50); + BOOST_CHECK(st.edge_with_same_filtration(st.find({0,1,2,5}))==st.find({0,2})); + BOOST_CHECK(st.edge_with_same_filtration(st.find({1,5}))==st.find({1,5})); + } } diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index 94b454e2..a9b21fa5 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -36,10 +36,10 @@ Note that persistence diagrams must be submitted as (n x 2) numpy arrays and mus import gudhi.wasserstein import numpy as np - diag1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) - diag2 = np.array([[2.8, 4.45],[9.5, 14.1]]) + dgm1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) + dgm2 = np.array([[2.8, 4.45],[9.5, 14.1]]) - message = "Wasserstein distance value = " + '%.2f' % gudhi.wasserstein.wasserstein_distance(diag1, diag2, order=1., internal_p=2.) + message = "Wasserstein distance value = " + '%.2f' % gudhi.wasserstein.wasserstein_distance(dgm1, dgm2, order=1., internal_p=2.) print(message) The output is: @@ -47,3 +47,40 @@ The output is: .. testoutput:: Wasserstein distance value = 1.45 + +We can also have access to the optimal matching by letting `matching=True`. +It is encoded as a list of indices (i,j), meaning that the i-th point in X +is mapped to the j-th point in Y. +An index of -1 represents the diagonal. + +.. testcode:: + + import gudhi.wasserstein + import numpy as np + + dgm1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) + dgm2 = np.array([[2.8, 4.45], [5, 6], [9.5, 14.1]]) + cost, matchings = gudhi.wasserstein.wasserstein_distance(dgm1, dgm2, matching=True, order=1, internal_p=2) + + message_cost = "Wasserstein distance value = %.2f" %cost + print(message_cost) + dgm1_to_diagonal = matchings[matchings[:,1] == -1, 0] + dgm2_to_diagonal = matchings[matchings[:,0] == -1, 1] + off_diagonal_match = np.delete(matchings, np.where(matchings == -1)[0], axis=0) + + for i,j in off_diagonal_match: + print("point %s in dgm1 is matched to point %s in dgm2" %(i,j)) + for i in dgm1_to_diagonal: + print("point %s in dgm1 is matched to the diagonal" %i) + for j in dgm2_to_diagonal: + print("point %s in dgm2 is matched to the diagonal" %j) + +The output is: + +.. testoutput:: + + Wasserstein distance value = 2.15 + point 0 in dgm1 is matched to point 0 in dgm2 + point 1 in dgm1 is matched to point 2 in dgm2 + point 2 in dgm1 is matched to the diagonal + point 1 in dgm2 is matched to the diagonal diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein.py index 13102094..3dd993f9 100644 --- a/src/python/gudhi/wasserstein.py +++ b/src/python/gudhi/wasserstein.py @@ -30,8 +30,10 @@ def _build_dist_matrix(X, Y, order=2., internal_p=2.): :param order: exponent for the Wasserstein metric. :param internal_p: Ground metric (i.e. norm L^p). :returns: (n+1) x (m+1) np.array encoding the cost matrix C. - For 1 <= i <= n, 1 <= j <= m, C[i,j] encodes the distance between X[i] and Y[j], while C[i, m+1] (resp. C[n+1, j]) encodes the distance (to the p) between X[i] (resp Y[j]) and its orthogonal proj onto the diagonal. - note also that C[n+1, m+1] = 0 (it costs nothing to move from the diagonal to the diagonal). + For 0 <= i < n, 0 <= j < m, C[i,j] encodes the distance between X[i] and Y[j], + while C[i, m] (resp. C[n, j]) encodes the distance (to the p) between X[i] (resp Y[j]) + and its orthogonal projection onto the diagonal. + note also that C[n, m] = 0 (it costs nothing to move from the diagonal to the diagonal). ''' Xdiag = _proj_on_diag(X) Ydiag = _proj_on_diag(Y) @@ -62,14 +64,20 @@ def _perstot(X, order, internal_p): return (np.sum(np.linalg.norm(X - Xdiag, ord=internal_p, axis=1)**order))**(1./order) -def wasserstein_distance(X, Y, order=2., internal_p=2.): +def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.): ''' - :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate). + :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points + (i.e. with infinite coordinate). :param Y: (m x 2) numpy.array encoding the second diagram. + :param matching: if True, computes and returns the optimal matching between X and Y, encoded as + a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to + the j-th point in Y, with the convention (-1) represents the diagonal. :param order: exponent for Wasserstein; Default value is 2. - :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm). - :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with respect to the internal_p-norm as ground metric. - :rtype: float + :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); + Default value is 2 (Euclidean norm). + :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with + respect to the internal_p-norm as ground metric. + If matching is set to True, also returns the optimal matching between X and Y. ''' n = len(X) m = len(Y) @@ -77,21 +85,40 @@ def wasserstein_distance(X, Y, order=2., internal_p=2.): # handle empty diagrams if X.size == 0: if Y.size == 0: - return 0. + if not matching: + return 0. + else: + return 0., np.array([]) else: - return _perstot(Y, order, internal_p) + if not matching: + return _perstot(Y, order, internal_p) + else: + return _perstot(Y, order, internal_p), np.array([[-1, j] for j in range(m)]) elif Y.size == 0: - return _perstot(X, order, internal_p) + if not matching: + return _perstot(X, order, internal_p) + else: + return _perstot(X, order, internal_p), np.array([[i, -1] for i in range(n)]) M = _build_dist_matrix(X, Y, order=order, internal_p=internal_p) - a = np.full(n+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here. - a[-1] = a[-1] * m # normalized so that we have a probability measure, required by POT - b = np.full(m+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here. - b[-1] = b[-1] * n # so that we have a probability measure, required by POT + a = np.ones(n+1) # weight vector of the input diagram. Uniform here. + a[-1] = m + b = np.ones(m+1) # weight vector of the input diagram. Uniform here. + b[-1] = n + + if matching: + P = ot.emd(a=a,b=b,M=M, numItermax=2000000) + ot_cost = np.sum(np.multiply(P,M)) + P[-1, -1] = 0 # Remove matching corresponding to the diagonal + match = np.argwhere(P) + # Now we turn to -1 points encoding the diagonal + match[:,0][match[:,0] >= n] = -1 + match[:,1][match[:,1] >= m] = -1 + return ot_cost ** (1./order) , match # Comptuation of the otcost using the ot.emd2 library. # Note: it is the Wasserstein distance to the power q. # The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value? - ot_cost = (n+m) * ot.emd2(a, b, M, numItermax=2000000) + ot_cost = ot.emd2(a, b, M, numItermax=2000000) return ot_cost ** (1./order) diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 6a6b217b..0d70e11a 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -17,7 +17,7 @@ __author__ = "Theo Lacombe" __copyright__ = "Copyright (C) 2019 Inria" __license__ = "MIT" -def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True): +def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) diag2 = np.array([[2.8, 4.45], [9.5, 14.1]]) diag3 = np.array([[0, 2], [4, 6]]) @@ -51,14 +51,27 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True): assert wasserstein_distance(diag3, diag4, internal_p=1., order=2.) == approx(np.sqrt(5)) assert wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.) == approx(np.sqrt(5)) - if(not test_infinity): - return + if test_infinity: + diag5 = np.array([[0, 3], [4, np.inf]]) + diag6 = np.array([[7, 8], [4, 6], [3, np.inf]]) - diag5 = np.array([[0, 3], [4, np.inf]]) - diag6 = np.array([[7, 8], [4, 6], [3, np.inf]]) + assert wasserstein_distance(diag4, diag5) == np.inf + assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.) + + + if test_matching: + match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1] + assert np.array_equal(match, []) + match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1] + assert np.array_equal(match, []) + match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1] + assert np.array_equal(match , [[-1, 0], [-1, 1]]) + match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1] + assert np.array_equal(match , [[0, -1], [1, -1]]) + match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1] + assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]]) + - assert wasserstein_distance(diag4, diag5) == np.inf - assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.) def hera_wrap(delta): def fun(*kargs,**kwargs): @@ -66,8 +79,8 @@ def hera_wrap(delta): return fun def test_wasserstein_distance_pot(): - _basic_wasserstein(pot, 1e-15, test_infinity=False) + _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) def test_wasserstein_distance_hera(): - _basic_wasserstein(hera_wrap(1e-12), 1e-12) - _basic_wasserstein(hera_wrap(.1), .1) + _basic_wasserstein(hera_wrap(1e-12), 1e-12, test_matching=False) + _basic_wasserstein(hera_wrap(.1), .1, test_matching=False) |