diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2022-11-16 00:20:04 +0100 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2022-11-16 14:06:50 +0100 |
commit | 8b658271dd38f1aaffbe94be8978cf5cea8ec7de (patch) | |
tree | 67e0c2af2d3586e7120d8cea88aee991ec662976 | |
parent | 04370bae13251d0bcce205f253fb758f91fdf207 (diff) |
Output matching from hera.wasserstein_distance
-rw-r--r-- | src/python/gudhi/bottleneck.cc | 7 | ||||
-rw-r--r-- | src/python/gudhi/hera/bottleneck.cc | 9 | ||||
-rw-r--r-- | src/python/gudhi/hera/wasserstein.cc | 127 | ||||
-rw-r--r-- | src/python/include/pybind11_diagram_utils.h | 25 | ||||
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 13 |
5 files changed, 142 insertions, 39 deletions
diff --git a/src/python/gudhi/bottleneck.cc b/src/python/gudhi/bottleneck.cc index 8a3d669a..3a0fe473 100644 --- a/src/python/gudhi/bottleneck.cc +++ b/src/python/gudhi/bottleneck.cc @@ -12,6 +12,9 @@ #include <pybind11_diagram_utils.h> +// Indices are added internally in bottleneck_distance, they are not needed in the input. +static auto make_point(double x, double y, py::ssize_t) { return std::pair(x, y); }; + // For compatibility with older versions, we want to support e=None. // In C++17, the recommended way is std::optional<double>. double bottleneck(Dgm d1, Dgm d2, py::object epsilon) @@ -19,8 +22,8 @@ double bottleneck(Dgm d1, Dgm d2, py::object epsilon) double e = (std::numeric_limits<double>::min)(); if (!epsilon.is_none()) e = epsilon.cast<double>(); // I *think* the call to request() has to be before releasing the GIL. - auto diag1 = numpy_to_range_of_pairs(d1); - auto diag2 = numpy_to_range_of_pairs(d2); + auto diag1 = numpy_to_range_of_pairs(d1, make_point); + auto diag2 = numpy_to_range_of_pairs(d2, make_point); py::gil_scoped_release release; diff --git a/src/python/gudhi/hera/bottleneck.cc b/src/python/gudhi/hera/bottleneck.cc index ec461f7c..6b919b02 100644 --- a/src/python/gudhi/hera/bottleneck.cc +++ b/src/python/gudhi/hera/bottleneck.cc @@ -16,13 +16,16 @@ using py::ssize_t; #endif -#include <hera/bottleneck.h> // Hera +#include <hera/bottleneck.h> + +// Indices are added internally in bottleneck_distance, they are not needed in the input. +static auto make_point(double x, double y, py::ssize_t) { return std::pair(x, y); }; double bottleneck_distance(Dgm d1, Dgm d2, double delta) { // I *think* the call to request() has to be before releasing the GIL. - auto diag1 = numpy_to_range_of_pairs(d1); - auto diag2 = numpy_to_range_of_pairs(d2); + auto diag1 = numpy_to_range_of_pairs(d1, make_point); + auto diag2 = numpy_to_range_of_pairs(d2, make_point); py::gil_scoped_release release; diff --git a/src/python/gudhi/hera/wasserstein.cc b/src/python/gudhi/hera/wasserstein.cc index 3516352e..8c731530 100644 --- a/src/python/gudhi/hera/wasserstein.cc +++ b/src/python/gudhi/hera/wasserstein.cc @@ -16,27 +16,118 @@ using py::ssize_t; #endif -#include <hera/wasserstein.h> // Hera +#include <hera/wasserstein.h> +#include <gudhi/Debug_utils.h> -double wasserstein_distance( +// Unlike bottleneck, for wasserstein, we need to add the index ourselves (if we want the matching) +static auto make_hera_point(double x, double y, py::ssize_t i) { return hera::DiagramPoint<double>(x, y, i); }; + +py::object wasserstein_distance( Dgm d1, Dgm d2, double wasserstein_power, double internal_p, - double delta) + double delta, bool return_matching) { // I *think* the call to request() has to be before releasing the GIL. - auto diag1 = numpy_to_range_of_pairs(d1); - auto diag2 = numpy_to_range_of_pairs(d2); - - py::gil_scoped_release release; - - hera::AuctionParams<double> params; - params.wasserstein_power = wasserstein_power; - // hera encodes infinity as -1... - if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>(); - params.internal_p = internal_p; - params.delta = delta; - // The extra parameters are purposely not exposed for now. - return hera::wasserstein_dist(diag1, diag2, params); + auto diag1 = numpy_to_range_of_pairs(d1, make_hera_point); + auto diag2 = numpy_to_range_of_pairs(d2, make_hera_point); + int n1 = boost::size(diag1); + int n2 = boost::size(diag2); + hera::AuctionResult<double> res; + double dist; + + { // No Python allowed in this section + py::gil_scoped_release release; + + hera::AuctionParams<double> params; + params.wasserstein_power = wasserstein_power; + // hera encodes infinity as -1... + if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>(); + params.internal_p = internal_p; + params.delta = delta; + if(return_matching) { + params.return_matching = true; + params.match_inf_points = true; + } + // The extra parameters are purposely not exposed for now. + res = hera::wasserstein_cost_detailed(diag1, diag2, params); + dist = std::pow(res.cost, 1./params.wasserstein_power); + } + + if(!return_matching) + return py::cast(dist); + + if(dist == std::numeric_limits<double>::infinity()) + return py::make_tuple(dist, py::none()); + + // bug in Hera, matching_a_to_b_ is empty if one diagram is empty or both diagrams contain the same points + if(res.matching_a_to_b_.size() == 0) { + if(n1 == 0) { // diag1 is empty + py::array_t<int> matching({{ n2, 2 }}, nullptr); + auto m = matching.mutable_unchecked(); + for(int j=0; j<n2; ++j){ + m(j, 0) = -1; + m(j, 1) = j; + } + return py::make_tuple(dist, matching); + } + if(n2 == 0) { // diag2 is empty + py::array_t<int> matching({{ n1, 2 }}, nullptr); + auto m = matching.mutable_unchecked(); + for(int i=0; i<n1; ++i){ + m(i, 0) = i; + m(i, 1) = -1; + } + return py::make_tuple(dist, matching); + } + // The only remaining case should be that the 2 diagrams are identical, but possibly shuffled + GUDHI_CHECK(n1==n2, "unexpected bug in Hera?"); + std::vector v1(boost::begin(diag1), boost::end(diag1)); + std::vector v2(boost::begin(diag2), boost::end(diag2)); + std::sort(v1.begin(), v1.end()); + std::sort(v2.begin(), v2.end()); + py::array_t<int> matching({{ n1, 2 }}, nullptr); + auto m = matching.mutable_unchecked(); + for(int i=0; i<n1; ++i){ + GUDHI_CHECK(v1[i][0]==v2[i][0] && v1[i][1]==v2[i][1], "unexpected bug in Hera?"); + m(i, 0) = v1[i].get_id(); + m(i, 1) = v2[i].get_id(); + } + return py::make_tuple(dist, matching); + + } + + // bug in Hera, diagonal points are ignored and don't appear in matching_a_to_b_ + for(auto p : diag1) + if(p[0] == p[1]) { auto id = p.get_id(); res.matching_a_to_b_[id] = -id-1; } + for(auto p : diag2) + if(p[0] == p[1]) { auto id = p.get_id(); res.matching_a_to_b_[-id-1] = id; } + + py::array_t<int> matching({{ n1 + n2, 2 }}, nullptr); + auto m = matching.mutable_unchecked(); + int cur = 0; + for(auto x : res.matching_a_to_b_){ + if(x.first < 0) { + if(x.second < 0) { + } else { + m(cur, 0) = -1; + m(cur, 1) = x.second; + ++cur; + } + } else { + if(x.second < 0) { + m(cur, 0) = x.first; + m(cur, 1) = -1; + ++cur; + } else { + m(cur, 0) = x.first; + m(cur, 1) = x.second; + ++cur; + } + } + } + // n1+n2 was too much, it only happens if everything matches to the diagonal, so we return matching[:cur,:] + py::array_t<int> ret({{ cur, 2 }}, {{ matching.strides()[0], matching.strides()[1] }}, matching.data(), matching); + return py::make_tuple(dist, ret); } PYBIND11_MODULE(wasserstein, m) { @@ -45,6 +136,7 @@ PYBIND11_MODULE(wasserstein, m) { py::arg("order") = 1, py::arg("internal_p") = std::numeric_limits<double>::infinity(), py::arg("delta") = .01, + py::arg("matching") = false, R"pbdoc( Compute the Wasserstein distance between two diagrams. Points at infinity are supported. @@ -55,8 +147,9 @@ PYBIND11_MODULE(wasserstein, m) { order (float): Wasserstein exponent W_q internal_p (float): Internal Minkowski norm L^p in R^2 delta (float): Relative error 1+delta + matching (bool): if ``True``, computes and returns the optimal matching between X and Y, encoded as a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to the j-th point in Y, with the convention that (-1) represents the diagonal. Returns: - float: Approximate Wasserstein distance W_q(X,Y) + float|Tuple[float,numpy.array]: Approximate Wasserstein distance W_q(X,Y), and optionally the corresponding matching )pbdoc"); } diff --git a/src/python/include/pybind11_diagram_utils.h b/src/python/include/pybind11_diagram_utils.h index 2d5194f4..5cb7c48b 100644 --- a/src/python/include/pybind11_diagram_utils.h +++ b/src/python/include/pybind11_diagram_utils.h @@ -17,16 +17,9 @@ namespace py = pybind11; typedef py::array_t<double> Dgm; -// Get m[i,0] and m[i,1] as a pair -static auto pairify(void* p, py::ssize_t h, py::ssize_t w) { - return [=](py::ssize_t i){ - char* birth = (char*)p + i * h; - char* death = birth + w; - return std::make_pair(*(double*)birth, *(double*)death); - }; -} - -inline auto numpy_to_range_of_pairs(py::array_t<double> dgm) { +// build_point(double birth, double death, ssize_t index) -> Point +template<class BuildPoint> +inline auto numpy_to_range_of_pairs(py::array_t<double> dgm, BuildPoint build_point) { py::buffer_info buf = dgm.request(); // shape (n,2) or (0) for empty if((buf.ndim!=2 || buf.shape[1]!=2) && (buf.ndim!=1 || buf.shape[0]!=0)) @@ -34,6 +27,16 @@ inline auto numpy_to_range_of_pairs(py::array_t<double> dgm) { // In the case of shape (0), avoid reading non-existing strides[1] even if we won't use it. py::ssize_t stride1 = buf.ndim == 2 ? buf.strides[1] : 0; auto cnt = boost::counting_range<py::ssize_t>(0, buf.shape[0]); - return boost::adaptors::transform(cnt, pairify(buf.ptr, buf.strides[0], stride1)); + + char* p = static_cast<char*>(buf.ptr); + auto h = buf.strides[0]; + auto w = stride1; + // Get m[i,0] and m[i,1] as a pair + auto pairify = [=](py::ssize_t i){ + char* birth = p + i * h; + char* death = birth + w; + return build_point(*(double*)birth, *(double*)death, i); + }; + return boost::adaptors::transform(cnt, pairify); // Be careful that the returned range cannot contain references to dead temporaries. } diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 3a004d77..1cac3e1a 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -140,9 +140,10 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat if test_matching: match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1] - assert np.array_equal(match, []) + # Accept [] or np.array of shape (2, 0) + assert len(match) == 0 match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1] - assert np.array_equal(match, []) + assert len(match) == 0 match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1] assert np.array_equal(match , [[-1, 0], [-1, 1]]) match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1] @@ -171,10 +172,10 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat assert (match is None) cost, match = wasserstein_distance(diag9, diag10, matching=True, internal_p=1., order=1.) assert (cost == 1) - assert (match == [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # type 4 and 5 are match to the diag anyway. + assert {(i,j) for i,j in match} == {(0, -1),(1, -1),(-1, 0), (-1, 1), (-1, 2)} # type 4 and 5 are match to the diag anyway. cost, match = wasserstein_distance(diag9, emptydiag, matching=True, internal_p=2., order=2.) assert (cost == 0.) - assert (match == [[0, -1], [1, -1]]) + assert np.array_equal(match, [[0, -1], [1, -1]]) def hera_wrap(**extra): @@ -195,6 +196,6 @@ def test_wasserstein_distance_pot(): def test_wasserstein_distance_hera(): - _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False) - _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False) + _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=True) + _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=True) |