summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2022-11-16 00:20:04 +0100
committerMarc Glisse <marc.glisse@inria.fr>2022-11-16 14:06:50 +0100
commit8b658271dd38f1aaffbe94be8978cf5cea8ec7de (patch)
tree67e0c2af2d3586e7120d8cea88aee991ec662976
parent04370bae13251d0bcce205f253fb758f91fdf207 (diff)
Output matching from hera.wasserstein_distance
-rw-r--r--src/python/gudhi/bottleneck.cc7
-rw-r--r--src/python/gudhi/hera/bottleneck.cc9
-rw-r--r--src/python/gudhi/hera/wasserstein.cc127
-rw-r--r--src/python/include/pybind11_diagram_utils.h25
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py13
5 files changed, 142 insertions, 39 deletions
diff --git a/src/python/gudhi/bottleneck.cc b/src/python/gudhi/bottleneck.cc
index 8a3d669a..3a0fe473 100644
--- a/src/python/gudhi/bottleneck.cc
+++ b/src/python/gudhi/bottleneck.cc
@@ -12,6 +12,9 @@
#include <pybind11_diagram_utils.h>
+// Indices are added internally in bottleneck_distance, they are not needed in the input.
+static auto make_point(double x, double y, py::ssize_t) { return std::pair(x, y); };
+
// For compatibility with older versions, we want to support e=None.
// In C++17, the recommended way is std::optional<double>.
double bottleneck(Dgm d1, Dgm d2, py::object epsilon)
@@ -19,8 +22,8 @@ double bottleneck(Dgm d1, Dgm d2, py::object epsilon)
double e = (std::numeric_limits<double>::min)();
if (!epsilon.is_none()) e = epsilon.cast<double>();
// I *think* the call to request() has to be before releasing the GIL.
- auto diag1 = numpy_to_range_of_pairs(d1);
- auto diag2 = numpy_to_range_of_pairs(d2);
+ auto diag1 = numpy_to_range_of_pairs(d1, make_point);
+ auto diag2 = numpy_to_range_of_pairs(d2, make_point);
py::gil_scoped_release release;
diff --git a/src/python/gudhi/hera/bottleneck.cc b/src/python/gudhi/hera/bottleneck.cc
index ec461f7c..6b919b02 100644
--- a/src/python/gudhi/hera/bottleneck.cc
+++ b/src/python/gudhi/hera/bottleneck.cc
@@ -16,13 +16,16 @@
using py::ssize_t;
#endif
-#include <hera/bottleneck.h> // Hera
+#include <hera/bottleneck.h>
+
+// Indices are added internally in bottleneck_distance, they are not needed in the input.
+static auto make_point(double x, double y, py::ssize_t) { return std::pair(x, y); };
double bottleneck_distance(Dgm d1, Dgm d2, double delta)
{
// I *think* the call to request() has to be before releasing the GIL.
- auto diag1 = numpy_to_range_of_pairs(d1);
- auto diag2 = numpy_to_range_of_pairs(d2);
+ auto diag1 = numpy_to_range_of_pairs(d1, make_point);
+ auto diag2 = numpy_to_range_of_pairs(d2, make_point);
py::gil_scoped_release release;
diff --git a/src/python/gudhi/hera/wasserstein.cc b/src/python/gudhi/hera/wasserstein.cc
index 3516352e..8c731530 100644
--- a/src/python/gudhi/hera/wasserstein.cc
+++ b/src/python/gudhi/hera/wasserstein.cc
@@ -16,27 +16,118 @@
using py::ssize_t;
#endif
-#include <hera/wasserstein.h> // Hera
+#include <hera/wasserstein.h>
+#include <gudhi/Debug_utils.h>
-double wasserstein_distance(
+// Unlike bottleneck, for wasserstein, we need to add the index ourselves (if we want the matching)
+static auto make_hera_point(double x, double y, py::ssize_t i) { return hera::DiagramPoint<double>(x, y, i); };
+
+py::object wasserstein_distance(
Dgm d1, Dgm d2,
double wasserstein_power, double internal_p,
- double delta)
+ double delta, bool return_matching)
{
// I *think* the call to request() has to be before releasing the GIL.
- auto diag1 = numpy_to_range_of_pairs(d1);
- auto diag2 = numpy_to_range_of_pairs(d2);
-
- py::gil_scoped_release release;
-
- hera::AuctionParams<double> params;
- params.wasserstein_power = wasserstein_power;
- // hera encodes infinity as -1...
- if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>();
- params.internal_p = internal_p;
- params.delta = delta;
- // The extra parameters are purposely not exposed for now.
- return hera::wasserstein_dist(diag1, diag2, params);
+ auto diag1 = numpy_to_range_of_pairs(d1, make_hera_point);
+ auto diag2 = numpy_to_range_of_pairs(d2, make_hera_point);
+ int n1 = boost::size(diag1);
+ int n2 = boost::size(diag2);
+ hera::AuctionResult<double> res;
+ double dist;
+
+ { // No Python allowed in this section
+ py::gil_scoped_release release;
+
+ hera::AuctionParams<double> params;
+ params.wasserstein_power = wasserstein_power;
+ // hera encodes infinity as -1...
+ if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>();
+ params.internal_p = internal_p;
+ params.delta = delta;
+ if(return_matching) {
+ params.return_matching = true;
+ params.match_inf_points = true;
+ }
+ // The extra parameters are purposely not exposed for now.
+ res = hera::wasserstein_cost_detailed(diag1, diag2, params);
+ dist = std::pow(res.cost, 1./params.wasserstein_power);
+ }
+
+ if(!return_matching)
+ return py::cast(dist);
+
+ if(dist == std::numeric_limits<double>::infinity())
+ return py::make_tuple(dist, py::none());
+
+ // bug in Hera, matching_a_to_b_ is empty if one diagram is empty or both diagrams contain the same points
+ if(res.matching_a_to_b_.size() == 0) {
+ if(n1 == 0) { // diag1 is empty
+ py::array_t<int> matching({{ n2, 2 }}, nullptr);
+ auto m = matching.mutable_unchecked();
+ for(int j=0; j<n2; ++j){
+ m(j, 0) = -1;
+ m(j, 1) = j;
+ }
+ return py::make_tuple(dist, matching);
+ }
+ if(n2 == 0) { // diag2 is empty
+ py::array_t<int> matching({{ n1, 2 }}, nullptr);
+ auto m = matching.mutable_unchecked();
+ for(int i=0; i<n1; ++i){
+ m(i, 0) = i;
+ m(i, 1) = -1;
+ }
+ return py::make_tuple(dist, matching);
+ }
+ // The only remaining case should be that the 2 diagrams are identical, but possibly shuffled
+ GUDHI_CHECK(n1==n2, "unexpected bug in Hera?");
+ std::vector v1(boost::begin(diag1), boost::end(diag1));
+ std::vector v2(boost::begin(diag2), boost::end(diag2));
+ std::sort(v1.begin(), v1.end());
+ std::sort(v2.begin(), v2.end());
+ py::array_t<int> matching({{ n1, 2 }}, nullptr);
+ auto m = matching.mutable_unchecked();
+ for(int i=0; i<n1; ++i){
+ GUDHI_CHECK(v1[i][0]==v2[i][0] && v1[i][1]==v2[i][1], "unexpected bug in Hera?");
+ m(i, 0) = v1[i].get_id();
+ m(i, 1) = v2[i].get_id();
+ }
+ return py::make_tuple(dist, matching);
+
+ }
+
+ // bug in Hera, diagonal points are ignored and don't appear in matching_a_to_b_
+ for(auto p : diag1)
+ if(p[0] == p[1]) { auto id = p.get_id(); res.matching_a_to_b_[id] = -id-1; }
+ for(auto p : diag2)
+ if(p[0] == p[1]) { auto id = p.get_id(); res.matching_a_to_b_[-id-1] = id; }
+
+ py::array_t<int> matching({{ n1 + n2, 2 }}, nullptr);
+ auto m = matching.mutable_unchecked();
+ int cur = 0;
+ for(auto x : res.matching_a_to_b_){
+ if(x.first < 0) {
+ if(x.second < 0) {
+ } else {
+ m(cur, 0) = -1;
+ m(cur, 1) = x.second;
+ ++cur;
+ }
+ } else {
+ if(x.second < 0) {
+ m(cur, 0) = x.first;
+ m(cur, 1) = -1;
+ ++cur;
+ } else {
+ m(cur, 0) = x.first;
+ m(cur, 1) = x.second;
+ ++cur;
+ }
+ }
+ }
+ // n1+n2 was too much, it only happens if everything matches to the diagonal, so we return matching[:cur,:]
+ py::array_t<int> ret({{ cur, 2 }}, {{ matching.strides()[0], matching.strides()[1] }}, matching.data(), matching);
+ return py::make_tuple(dist, ret);
}
PYBIND11_MODULE(wasserstein, m) {
@@ -45,6 +136,7 @@ PYBIND11_MODULE(wasserstein, m) {
py::arg("order") = 1,
py::arg("internal_p") = std::numeric_limits<double>::infinity(),
py::arg("delta") = .01,
+ py::arg("matching") = false,
R"pbdoc(
Compute the Wasserstein distance between two diagrams.
Points at infinity are supported.
@@ -55,8 +147,9 @@ PYBIND11_MODULE(wasserstein, m) {
order (float): Wasserstein exponent W_q
internal_p (float): Internal Minkowski norm L^p in R^2
delta (float): Relative error 1+delta
+ matching (bool): if ``True``, computes and returns the optimal matching between X and Y, encoded as a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to the j-th point in Y, with the convention that (-1) represents the diagonal.
Returns:
- float: Approximate Wasserstein distance W_q(X,Y)
+ float|Tuple[float,numpy.array]: Approximate Wasserstein distance W_q(X,Y), and optionally the corresponding matching
)pbdoc");
}
diff --git a/src/python/include/pybind11_diagram_utils.h b/src/python/include/pybind11_diagram_utils.h
index 2d5194f4..5cb7c48b 100644
--- a/src/python/include/pybind11_diagram_utils.h
+++ b/src/python/include/pybind11_diagram_utils.h
@@ -17,16 +17,9 @@
namespace py = pybind11;
typedef py::array_t<double> Dgm;
-// Get m[i,0] and m[i,1] as a pair
-static auto pairify(void* p, py::ssize_t h, py::ssize_t w) {
- return [=](py::ssize_t i){
- char* birth = (char*)p + i * h;
- char* death = birth + w;
- return std::make_pair(*(double*)birth, *(double*)death);
- };
-}
-
-inline auto numpy_to_range_of_pairs(py::array_t<double> dgm) {
+// build_point(double birth, double death, ssize_t index) -> Point
+template<class BuildPoint>
+inline auto numpy_to_range_of_pairs(py::array_t<double> dgm, BuildPoint build_point) {
py::buffer_info buf = dgm.request();
// shape (n,2) or (0) for empty
if((buf.ndim!=2 || buf.shape[1]!=2) && (buf.ndim!=1 || buf.shape[0]!=0))
@@ -34,6 +27,16 @@ inline auto numpy_to_range_of_pairs(py::array_t<double> dgm) {
// In the case of shape (0), avoid reading non-existing strides[1] even if we won't use it.
py::ssize_t stride1 = buf.ndim == 2 ? buf.strides[1] : 0;
auto cnt = boost::counting_range<py::ssize_t>(0, buf.shape[0]);
- return boost::adaptors::transform(cnt, pairify(buf.ptr, buf.strides[0], stride1));
+
+ char* p = static_cast<char*>(buf.ptr);
+ auto h = buf.strides[0];
+ auto w = stride1;
+ // Get m[i,0] and m[i,1] as a pair
+ auto pairify = [=](py::ssize_t i){
+ char* birth = p + i * h;
+ char* death = birth + w;
+ return build_point(*(double*)birth, *(double*)death, i);
+ };
+ return boost::adaptors::transform(cnt, pairify);
// Be careful that the returned range cannot contain references to dead temporaries.
}
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 3a004d77..1cac3e1a 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -140,9 +140,10 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
if test_matching:
match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1]
- assert np.array_equal(match, [])
+ # Accept [] or np.array of shape (2, 0)
+ assert len(match) == 0
match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
- assert np.array_equal(match, [])
+ assert len(match) == 0
match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1]
assert np.array_equal(match , [[-1, 0], [-1, 1]])
match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
@@ -171,10 +172,10 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
assert (match is None)
cost, match = wasserstein_distance(diag9, diag10, matching=True, internal_p=1., order=1.)
assert (cost == 1)
- assert (match == [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # type 4 and 5 are match to the diag anyway.
+ assert {(i,j) for i,j in match} == {(0, -1),(1, -1),(-1, 0), (-1, 1), (-1, 2)} # type 4 and 5 are match to the diag anyway.
cost, match = wasserstein_distance(diag9, emptydiag, matching=True, internal_p=2., order=2.)
assert (cost == 0.)
- assert (match == [[0, -1], [1, -1]])
+ assert np.array_equal(match, [[0, -1], [1, -1]])
def hera_wrap(**extra):
@@ -195,6 +196,6 @@ def test_wasserstein_distance_pot():
def test_wasserstein_distance_hera():
- _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False)
- _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False)
+ _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=True)
+ _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=True)