summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMathieuCarriere <mathieu.carriere3@gmail.com>2020-03-17 12:17:43 -0400
committerMathieuCarriere <mathieu.carriere3@gmail.com>2020-03-17 12:17:43 -0400
commit50427c9fe5c63f3bfe27270db0ea26480e73cb1a (patch)
treed4f867c6076ad4ae425a77c4346b370921e6251f /src
parent58d923b13afb9b18a2d5b028c6575baee691d182 (diff)
parent5fdbad3fdd350edc3a5340110f4c5332c73517b3 (diff)
fix conflict
Diffstat (limited to 'src')
-rw-r--r--src/Simplex_tree/include/gudhi/Simplex_tree.h75
-rw-r--r--src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h12
-rw-r--r--src/Simplex_tree/test/simplex_tree_unit_test.cpp36
-rw-r--r--src/python/doc/wasserstein_distance_user.rst43
-rw-r--r--src/python/gudhi/wasserstein.py57
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py33
6 files changed, 216 insertions, 40 deletions
diff --git a/src/Simplex_tree/include/gudhi/Simplex_tree.h b/src/Simplex_tree/include/gudhi/Simplex_tree.h
index f661f687..1c06e7cb 100644
--- a/src/Simplex_tree/include/gudhi/Simplex_tree.h
+++ b/src/Simplex_tree/include/gudhi/Simplex_tree.h
@@ -24,6 +24,7 @@
#include <boost/iterator/transform_iterator.hpp>
#include <boost/graph/adjacency_list.hpp>
#include <boost/range/adaptor/reversed.hpp>
+#include <boost/container/static_vector.hpp>
#ifdef GUDHI_USE_TBB
#include <tbb/parallel_sort.h>
@@ -248,8 +249,8 @@ class Simplex_tree {
* which is consequenlty
* equal to \f$(-1)^{\text{dim} \sigma}\f$ the canonical orientation on the simplex.
*/
- Simplex_vertex_range simplex_vertex_range(Simplex_handle sh) {
- assert(sh != null_simplex()); // Empty simplex
+ Simplex_vertex_range simplex_vertex_range(Simplex_handle sh) const {
+ GUDHI_CHECK(sh != null_simplex(), "empty simplex");
return Simplex_vertex_range(Simplex_vertex_iterator(this, sh),
Simplex_vertex_iterator(this));
}
@@ -452,6 +453,15 @@ class Simplex_tree {
return true;
}
+ /** \brief Returns the filtration value of a simplex.
+ *
+ * Same as `filtration()`, but does not handle `null_simplex()`.
+ */
+ static Filtration_value filtration_(Simplex_handle sh) {
+ GUDHI_CHECK (sh != null_simplex(), "null simplex");
+ return sh->second.filtration();
+ }
+
public:
/** \brief Returns the key associated to a simplex.
*
@@ -829,7 +839,7 @@ class Simplex_tree {
/** Returns the Siblings containing a simplex.*/
template<class SimplexHandle>
- Siblings* self_siblings(SimplexHandle sh) {
+ static Siblings* self_siblings(SimplexHandle sh) {
if (sh->second.children()->parent() == sh->first)
return sh->second.children()->oncles();
else
@@ -1589,6 +1599,65 @@ class Simplex_tree {
this->make_filtration_non_decreasing();
}
+ /** \brief Returns a vertex of `sh` that has the same filtration value as `sh` if it exists, and `null_vertex()` otherwise.
+ *
+ * For a lower-star filtration built with `make_filtration_non_decreasing()`, this is a way to invert the process and find out which vertex had its filtration value propagated to `sh`.
+ * If several vertices have the same filtration value, the one it returns is arbitrary. */
+ Vertex_handle vertex_with_same_filtration(Simplex_handle sh) {
+ auto filt = filtration_(sh);
+ for(auto v : simplex_vertex_range(sh))
+ if(filtration_(find_vertex(v)) == filt)
+ return v;
+ return null_vertex();
+ }
+
+ /** \brief Returns an edge of `sh` that has the same filtration value as `sh` if it exists, and `null_simplex()` otherwise.
+ *
+ * For a flag-complex built with `expansion()`, this is a way to invert the process and find out which edge had its filtration value propagated to `sh`.
+ * If several edges have the same filtration value, the one it returns is arbitrary.
+ *
+ * \pre `sh` must have dimension at least 1. */
+ Simplex_handle edge_with_same_filtration(Simplex_handle sh) {
+ // See issue #251 for potential speed improvements.
+ auto&& vertices = simplex_vertex_range(sh); // vertices in decreasing order
+ auto end = std::end(vertices);
+ auto vi = std::begin(vertices);
+ GUDHI_CHECK(vi != end, "empty simplex");
+ auto v0 = *vi;
+ ++vi;
+ GUDHI_CHECK(vi != end, "simplex of dimension 0");
+ if(std::next(vi) == end) return sh; // shortcut for dimension 1
+ boost::container::static_vector<Vertex_handle, 40> suffix;
+ suffix.push_back(v0);
+ auto filt = filtration_(sh);
+ do
+ {
+ Vertex_handle v = *vi;
+ auto&& children1 = find_vertex(v)->second.children()->members_;
+ for(auto w : suffix){
+ // Can we take advantage of the fact that suffix is ordered?
+ Simplex_handle s = children1.find(w);
+ if(filtration_(s) == filt)
+ return s;
+ }
+ suffix.push_back(v);
+ }
+ while(++vi != end);
+ return null_simplex();
+ }
+
+ /** \brief Returns a minimal face of `sh` that has the same filtration value as `sh`.
+ *
+ * For a filtration built with `make_filtration_non_decreasing()`, this is a way to invert the process and find out which simplex had its filtration value propagated to `sh`.
+ * If several minimal (for inclusion) simplices have the same filtration value, the one it returns is arbitrary, and it is not guaranteed to be the one with smallest dimension. */
+ Simplex_handle minimal_simplex_with_same_filtration(Simplex_handle sh) {
+ auto filt = filtration_(sh);
+ // Naive implementation, it can be sped up.
+ for(auto b : boundary_simplex_range(sh))
+ if(filtration_(b) == filt)
+ return minimal_simplex_with_same_filtration(b);
+ return sh; // None of its faces has the same filtration.
+ }
private:
Vertex_handle null_vertex_;
diff --git a/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h b/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h
index efccf2f2..9007b6bd 100644
--- a/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h
+++ b/src/Simplex_tree/include/gudhi/Simplex_tree/Simplex_tree_iterators.h
@@ -15,9 +15,7 @@
#include <boost/iterator/iterator_facade.hpp>
#include <boost/version.hpp>
-#if BOOST_VERSION >= 105600
-# include <boost/container/static_vector.hpp>
-#endif
+#include <boost/container/static_vector.hpp>
#include <vector>
@@ -42,13 +40,13 @@ class Simplex_tree_simplex_vertex_iterator : public boost::iterator_facade<
typedef typename SimplexTree::Siblings Siblings;
typedef typename SimplexTree::Vertex_handle Vertex_handle;
- explicit Simplex_tree_simplex_vertex_iterator(SimplexTree * st)
+ explicit Simplex_tree_simplex_vertex_iterator(SimplexTree const* st)
: // any end() iterator
sib_(nullptr),
v_(st->null_vertex()) {
}
- Simplex_tree_simplex_vertex_iterator(SimplexTree * st, Simplex_handle sh)
+ Simplex_tree_simplex_vertex_iterator(SimplexTree const* st, Simplex_handle sh)
: sib_(st->self_siblings(sh)),
v_(sh->first) {
}
@@ -166,15 +164,11 @@ class Simplex_tree_boundary_simplex_iterator : public boost::iterator_facade<
// Most of the storage should be moved to the range, iterators should be light.
Vertex_handle last_; // last vertex of the simplex
Vertex_handle next_; // next vertex to push in suffix_
-#if BOOST_VERSION >= 105600
// 40 seems a conservative bound on the dimension of a Simplex_tree for now,
// as it would not fit on the biggest hard-drive.
boost::container::static_vector<Vertex_handle, 40> suffix_;
// static_vector still has some overhead compared to a trivial hand-made
// version using std::aligned_storage, or compared to making suffix_ static.
-#else
- std::vector<Vertex_handle> suffix_;
-#endif
Siblings * sib_; // where the next search will start from
Simplex_handle sh_; // current Simplex_handle in the boundary
SimplexTree * st_; // simplex containing the simplicial complex
diff --git a/src/Simplex_tree/test/simplex_tree_unit_test.cpp b/src/Simplex_tree/test/simplex_tree_unit_test.cpp
index e739ad0a..3872c4ca 100644
--- a/src/Simplex_tree/test/simplex_tree_unit_test.cpp
+++ b/src/Simplex_tree/test/simplex_tree_unit_test.cpp
@@ -902,5 +902,41 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(insert_duplicated_vertices, typeST, list_of_tested
<< " - num_simplices = " << st.num_simplices() << std::endl;
BOOST_CHECK(st.dimension() == 1);
BOOST_CHECK(st.num_simplices() == st.num_vertices() + 1);
+}
+BOOST_AUTO_TEST_CASE_TEMPLATE(generators, typeST, list_of_tested_variants) {
+ std::cout << "********************************************************************" << std::endl;
+ std::cout << "TEST FIND GENERATORS" << std::endl;
+ {
+ typeST st;
+ st.insert_simplex_and_subfaces({0,1,2,3,4,5,6},0);
+ st.assign_filtration(st.find({0,2,4}), 10);
+ st.assign_filtration(st.find({1,5}), 20);
+ st.assign_filtration(st.find({1,2,4}), 30);
+ st.assign_filtration(st.find({3}), 5);
+ st.make_filtration_non_decreasing();
+ BOOST_CHECK(st.filtration(st.find({1,2}))==0);
+ BOOST_CHECK(st.filtration(st.find({0,1,2,3,4}))==30);
+ BOOST_CHECK(st.minimal_simplex_with_same_filtration(st.find({0,1,2,3,4,5}))==st.find({1,2,4}));
+ BOOST_CHECK(st.minimal_simplex_with_same_filtration(st.find({0,2,3}))==st.find({3}));
+ auto s=st.minimal_simplex_with_same_filtration(st.find({0,2,6}));
+ BOOST_CHECK(s==st.find({0})||s==st.find({2})||s==st.find({6}));
+ BOOST_CHECK(st.vertex_with_same_filtration(st.find({2}))==2);
+ BOOST_CHECK(st.vertex_with_same_filtration(st.find({1,5}))==st.null_vertex());
+ BOOST_CHECK(st.vertex_with_same_filtration(st.find({5,6}))>=5);
+ }
+ {
+ typeST st;
+ st.insert_simplex_and_subfaces({0,1}, 8);
+ st.insert_simplex_and_subfaces({0,2}, 10);
+ st.insert_simplex_and_subfaces({3,4}, 6);
+ st.insert_simplex_and_subfaces({1,2}, 5);
+ st.insert_simplex_and_subfaces({1,5}, 4);
+ st.insert_simplex_and_subfaces({0,5}, 3);
+ st.insert_simplex_and_subfaces({2,5}, 2);
+ st.insert_simplex_and_subfaces({1,3}, 9);
+ st.expansion(50);
+ BOOST_CHECK(st.edge_with_same_filtration(st.find({0,1,2,5}))==st.find({0,2}));
+ BOOST_CHECK(st.edge_with_same_filtration(st.find({1,5}))==st.find({1,5}));
+ }
}
diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst
index 94b454e2..a9b21fa5 100644
--- a/src/python/doc/wasserstein_distance_user.rst
+++ b/src/python/doc/wasserstein_distance_user.rst
@@ -36,10 +36,10 @@ Note that persistence diagrams must be submitted as (n x 2) numpy arrays and mus
import gudhi.wasserstein
import numpy as np
- diag1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]])
- diag2 = np.array([[2.8, 4.45],[9.5, 14.1]])
+ dgm1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]])
+ dgm2 = np.array([[2.8, 4.45],[9.5, 14.1]])
- message = "Wasserstein distance value = " + '%.2f' % gudhi.wasserstein.wasserstein_distance(diag1, diag2, order=1., internal_p=2.)
+ message = "Wasserstein distance value = " + '%.2f' % gudhi.wasserstein.wasserstein_distance(dgm1, dgm2, order=1., internal_p=2.)
print(message)
The output is:
@@ -47,3 +47,40 @@ The output is:
.. testoutput::
Wasserstein distance value = 1.45
+
+We can also have access to the optimal matching by letting `matching=True`.
+It is encoded as a list of indices (i,j), meaning that the i-th point in X
+is mapped to the j-th point in Y.
+An index of -1 represents the diagonal.
+
+.. testcode::
+
+ import gudhi.wasserstein
+ import numpy as np
+
+ dgm1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]])
+ dgm2 = np.array([[2.8, 4.45], [5, 6], [9.5, 14.1]])
+ cost, matchings = gudhi.wasserstein.wasserstein_distance(dgm1, dgm2, matching=True, order=1, internal_p=2)
+
+ message_cost = "Wasserstein distance value = %.2f" %cost
+ print(message_cost)
+ dgm1_to_diagonal = matchings[matchings[:,1] == -1, 0]
+ dgm2_to_diagonal = matchings[matchings[:,0] == -1, 1]
+ off_diagonal_match = np.delete(matchings, np.where(matchings == -1)[0], axis=0)
+
+ for i,j in off_diagonal_match:
+ print("point %s in dgm1 is matched to point %s in dgm2" %(i,j))
+ for i in dgm1_to_diagonal:
+ print("point %s in dgm1 is matched to the diagonal" %i)
+ for j in dgm2_to_diagonal:
+ print("point %s in dgm2 is matched to the diagonal" %j)
+
+The output is:
+
+.. testoutput::
+
+ Wasserstein distance value = 2.15
+ point 0 in dgm1 is matched to point 0 in dgm2
+ point 1 in dgm1 is matched to point 2 in dgm2
+ point 2 in dgm1 is matched to the diagonal
+ point 1 in dgm2 is matched to the diagonal
diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein.py
index 13102094..3dd993f9 100644
--- a/src/python/gudhi/wasserstein.py
+++ b/src/python/gudhi/wasserstein.py
@@ -30,8 +30,10 @@ def _build_dist_matrix(X, Y, order=2., internal_p=2.):
:param order: exponent for the Wasserstein metric.
:param internal_p: Ground metric (i.e. norm L^p).
:returns: (n+1) x (m+1) np.array encoding the cost matrix C.
- For 1 <= i <= n, 1 <= j <= m, C[i,j] encodes the distance between X[i] and Y[j], while C[i, m+1] (resp. C[n+1, j]) encodes the distance (to the p) between X[i] (resp Y[j]) and its orthogonal proj onto the diagonal.
- note also that C[n+1, m+1] = 0 (it costs nothing to move from the diagonal to the diagonal).
+ For 0 <= i < n, 0 <= j < m, C[i,j] encodes the distance between X[i] and Y[j],
+ while C[i, m] (resp. C[n, j]) encodes the distance (to the p) between X[i] (resp Y[j])
+ and its orthogonal projection onto the diagonal.
+ note also that C[n, m] = 0 (it costs nothing to move from the diagonal to the diagonal).
'''
Xdiag = _proj_on_diag(X)
Ydiag = _proj_on_diag(Y)
@@ -62,14 +64,20 @@ def _perstot(X, order, internal_p):
return (np.sum(np.linalg.norm(X - Xdiag, ord=internal_p, axis=1)**order))**(1./order)
-def wasserstein_distance(X, Y, order=2., internal_p=2.):
+def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
'''
- :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate).
+ :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points
+ (i.e. with infinite coordinate).
:param Y: (m x 2) numpy.array encoding the second diagram.
+ :param matching: if True, computes and returns the optimal matching between X and Y, encoded as
+ a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to
+ the j-th point in Y, with the convention (-1) represents the diagonal.
:param order: exponent for Wasserstein; Default value is 2.
- :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm).
- :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with respect to the internal_p-norm as ground metric.
- :rtype: float
+ :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
+ Default value is 2 (Euclidean norm).
+ :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with
+ respect to the internal_p-norm as ground metric.
+ If matching is set to True, also returns the optimal matching between X and Y.
'''
n = len(X)
m = len(Y)
@@ -77,21 +85,40 @@ def wasserstein_distance(X, Y, order=2., internal_p=2.):
# handle empty diagrams
if X.size == 0:
if Y.size == 0:
- return 0.
+ if not matching:
+ return 0.
+ else:
+ return 0., np.array([])
else:
- return _perstot(Y, order, internal_p)
+ if not matching:
+ return _perstot(Y, order, internal_p)
+ else:
+ return _perstot(Y, order, internal_p), np.array([[-1, j] for j in range(m)])
elif Y.size == 0:
- return _perstot(X, order, internal_p)
+ if not matching:
+ return _perstot(X, order, internal_p)
+ else:
+ return _perstot(X, order, internal_p), np.array([[i, -1] for i in range(n)])
M = _build_dist_matrix(X, Y, order=order, internal_p=internal_p)
- a = np.full(n+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here.
- a[-1] = a[-1] * m # normalized so that we have a probability measure, required by POT
- b = np.full(m+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here.
- b[-1] = b[-1] * n # so that we have a probability measure, required by POT
+ a = np.ones(n+1) # weight vector of the input diagram. Uniform here.
+ a[-1] = m
+ b = np.ones(m+1) # weight vector of the input diagram. Uniform here.
+ b[-1] = n
+
+ if matching:
+ P = ot.emd(a=a,b=b,M=M, numItermax=2000000)
+ ot_cost = np.sum(np.multiply(P,M))
+ P[-1, -1] = 0 # Remove matching corresponding to the diagonal
+ match = np.argwhere(P)
+ # Now we turn to -1 points encoding the diagonal
+ match[:,0][match[:,0] >= n] = -1
+ match[:,1][match[:,1] >= m] = -1
+ return ot_cost ** (1./order) , match
# Comptuation of the otcost using the ot.emd2 library.
# Note: it is the Wasserstein distance to the power q.
# The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value?
- ot_cost = (n+m) * ot.emd2(a, b, M, numItermax=2000000)
+ ot_cost = ot.emd2(a, b, M, numItermax=2000000)
return ot_cost ** (1./order)
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 6a6b217b..0d70e11a 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -17,7 +17,7 @@ __author__ = "Theo Lacombe"
__copyright__ = "Copyright (C) 2019 Inria"
__license__ = "MIT"
-def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True):
+def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True):
diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]])
diag2 = np.array([[2.8, 4.45], [9.5, 14.1]])
diag3 = np.array([[0, 2], [4, 6]])
@@ -51,14 +51,27 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True):
assert wasserstein_distance(diag3, diag4, internal_p=1., order=2.) == approx(np.sqrt(5))
assert wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.) == approx(np.sqrt(5))
- if(not test_infinity):
- return
+ if test_infinity:
+ diag5 = np.array([[0, 3], [4, np.inf]])
+ diag6 = np.array([[7, 8], [4, 6], [3, np.inf]])
- diag5 = np.array([[0, 3], [4, np.inf]])
- diag6 = np.array([[7, 8], [4, 6], [3, np.inf]])
+ assert wasserstein_distance(diag4, diag5) == np.inf
+ assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.)
+
+
+ if test_matching:
+ match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1]
+ assert np.array_equal(match, [])
+ match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
+ assert np.array_equal(match, [])
+ match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1]
+ assert np.array_equal(match , [[-1, 0], [-1, 1]])
+ match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
+ assert np.array_equal(match , [[0, -1], [1, -1]])
+ match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1]
+ assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]])
+
- assert wasserstein_distance(diag4, diag5) == np.inf
- assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.)
def hera_wrap(delta):
def fun(*kargs,**kwargs):
@@ -66,8 +79,8 @@ def hera_wrap(delta):
return fun
def test_wasserstein_distance_pot():
- _basic_wasserstein(pot, 1e-15, test_infinity=False)
+ _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True)
def test_wasserstein_distance_hera():
- _basic_wasserstein(hera_wrap(1e-12), 1e-12)
- _basic_wasserstein(hera_wrap(.1), .1)
+ _basic_wasserstein(hera_wrap(1e-12), 1e-12, test_matching=False)
+ _basic_wasserstein(hera_wrap(.1), .1, test_matching=False)