summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVincent Rouvreau <10407034+VincentRouvreau@users.noreply.github.com>2023-01-11 10:48:23 +0100
committerGitHub <noreply@github.com>2023-01-11 10:48:23 +0100
commite22566a6721c050c964c5ad1f29a3e9066c12ba9 (patch)
tree828d24784dfc0b17889aa5ad84fd983d12a70c0c
parent7b9c272f7dad66ad9b05dcdd5ec43e86fda306c4 (diff)
parenteb6c94ea1125bf216ec5f07b2936dd115e461aa4 (diff)
Merge pull request #736 from mglisse/hera-match
Provide matching in hera.wasserstein_distance
-rw-r--r--src/python/gudhi/bottleneck.cc18
-rw-r--r--src/python/gudhi/hera/bottleneck.cc11
-rw-r--r--src/python/gudhi/hera/wasserstein.cc129
-rw-r--r--src/python/include/pybind11_diagram_utils.h25
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py28
5 files changed, 160 insertions, 51 deletions
diff --git a/src/python/gudhi/bottleneck.cc b/src/python/gudhi/bottleneck.cc
index 8a3d669a..040e6d37 100644
--- a/src/python/gudhi/bottleneck.cc
+++ b/src/python/gudhi/bottleneck.cc
@@ -9,18 +9,20 @@
*/
#include <gudhi/Bottleneck.h>
-
+#include <optional>
#include <pybind11_diagram_utils.h>
+#include <pybind11/stl.h>
+
+// Indices are added internally in bottleneck_distance, they are not needed in the input.
+static auto make_point(double x, double y, py::ssize_t) { return std::pair(x, y); };
// For compatibility with older versions, we want to support e=None.
-// In C++17, the recommended way is std::optional<double>.
-double bottleneck(Dgm d1, Dgm d2, py::object epsilon)
+double bottleneck(Dgm d1, Dgm d2, std::optional<double> epsilon)
{
- double e = (std::numeric_limits<double>::min)();
- if (!epsilon.is_none()) e = epsilon.cast<double>();
- // I *think* the call to request() has to be before releasing the GIL.
- auto diag1 = numpy_to_range_of_pairs(d1);
- auto diag2 = numpy_to_range_of_pairs(d2);
+ double e = epsilon.value_or((std::numeric_limits<double>::min)());
+ // I *think* the call to request() in numpy_to_range_of_pairs has to be before releasing the GIL.
+ auto diag1 = numpy_to_range_of_pairs(d1, make_point);
+ auto diag2 = numpy_to_range_of_pairs(d2, make_point);
py::gil_scoped_release release;
diff --git a/src/python/gudhi/hera/bottleneck.cc b/src/python/gudhi/hera/bottleneck.cc
index ec461f7c..9826252c 100644
--- a/src/python/gudhi/hera/bottleneck.cc
+++ b/src/python/gudhi/hera/bottleneck.cc
@@ -16,13 +16,16 @@
using py::ssize_t;
#endif
-#include <hera/bottleneck.h> // Hera
+#include <hera/bottleneck.h>
+
+// Indices are added internally in bottleneck_distance, they are not needed in the input.
+static auto make_point(double x, double y, py::ssize_t) { return std::pair(x, y); };
double bottleneck_distance(Dgm d1, Dgm d2, double delta)
{
- // I *think* the call to request() has to be before releasing the GIL.
- auto diag1 = numpy_to_range_of_pairs(d1);
- auto diag2 = numpy_to_range_of_pairs(d2);
+ // I *think* the call to request() in numpy_to_range_of_pairs has to be before releasing the GIL.
+ auto diag1 = numpy_to_range_of_pairs(d1, make_point);
+ auto diag2 = numpy_to_range_of_pairs(d2, make_point);
py::gil_scoped_release release;
diff --git a/src/python/gudhi/hera/wasserstein.cc b/src/python/gudhi/hera/wasserstein.cc
index 3516352e..41e84f7b 100644
--- a/src/python/gudhi/hera/wasserstein.cc
+++ b/src/python/gudhi/hera/wasserstein.cc
@@ -16,27 +16,118 @@
using py::ssize_t;
#endif
-#include <hera/wasserstein.h> // Hera
+#include <hera/wasserstein.h>
+#include <gudhi/Debug_utils.h>
-double wasserstein_distance(
+// Unlike bottleneck, for wasserstein, we need to add the index ourselves (if we want the matching)
+static auto make_hera_point(double x, double y, py::ssize_t i) { return hera::DiagramPoint<double>(x, y, i); };
+
+py::object wasserstein_distance(
Dgm d1, Dgm d2,
double wasserstein_power, double internal_p,
- double delta)
+ double delta, bool return_matching)
{
- // I *think* the call to request() has to be before releasing the GIL.
- auto diag1 = numpy_to_range_of_pairs(d1);
- auto diag2 = numpy_to_range_of_pairs(d2);
-
- py::gil_scoped_release release;
-
- hera::AuctionParams<double> params;
- params.wasserstein_power = wasserstein_power;
- // hera encodes infinity as -1...
- if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>();
- params.internal_p = internal_p;
- params.delta = delta;
- // The extra parameters are purposely not exposed for now.
- return hera::wasserstein_dist(diag1, diag2, params);
+ // I *think* the call to request() in numpy_to_range_of_pairs has to be before releasing the GIL.
+ auto diag1 = numpy_to_range_of_pairs(d1, make_hera_point);
+ auto diag2 = numpy_to_range_of_pairs(d2, make_hera_point);
+ int n1 = boost::size(diag1);
+ int n2 = boost::size(diag2);
+ hera::AuctionResult<double> res;
+ double dist;
+
+ { // No Python allowed in this section
+ py::gil_scoped_release release;
+
+ hera::AuctionParams<double> params;
+ params.wasserstein_power = wasserstein_power;
+ // hera encodes infinity as -1...
+ if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>();
+ params.internal_p = internal_p;
+ params.delta = delta;
+ if(return_matching) {
+ params.return_matching = true;
+ params.match_inf_points = true;
+ }
+ // The extra parameters are purposely not exposed for now.
+ res = hera::wasserstein_cost_detailed(diag1, diag2, params);
+ dist = std::pow(res.cost, 1./params.wasserstein_power);
+ }
+
+ if(!return_matching)
+ return py::cast(dist);
+
+ if(dist == std::numeric_limits<double>::infinity())
+ return py::make_tuple(dist, py::none());
+
+ // bug in Hera, matching_a_to_b_ is empty if one diagram is empty or both diagrams contain the same points
+ if(res.matching_a_to_b_.size() == 0) {
+ if(n1 == 0) { // diag1 is empty
+ py::array_t<int> matching({{ n2, 2 }}, nullptr);
+ auto m = matching.mutable_unchecked();
+ for(int j=0; j<n2; ++j){
+ m(j, 0) = -1;
+ m(j, 1) = j;
+ }
+ return py::make_tuple(dist, matching);
+ }
+ if(n2 == 0) { // diag2 is empty
+ py::array_t<int> matching({{ n1, 2 }}, nullptr);
+ auto m = matching.mutable_unchecked();
+ for(int i=0; i<n1; ++i){
+ m(i, 0) = i;
+ m(i, 1) = -1;
+ }
+ return py::make_tuple(dist, matching);
+ }
+ // The only remaining case should be that the 2 diagrams are identical, but possibly shuffled
+ GUDHI_CHECK(n1==n2, "unexpected bug in Hera?");
+ std::vector v1(boost::begin(diag1), boost::end(diag1));
+ std::vector v2(boost::begin(diag2), boost::end(diag2));
+ std::sort(v1.begin(), v1.end());
+ std::sort(v2.begin(), v2.end());
+ py::array_t<int> matching({{ n1, 2 }}, nullptr);
+ auto m = matching.mutable_unchecked();
+ for(int i=0; i<n1; ++i){
+ GUDHI_CHECK(v1[i][0]==v2[i][0] && v1[i][1]==v2[i][1], "unexpected bug in Hera?");
+ m(i, 0) = v1[i].get_id();
+ m(i, 1) = v2[i].get_id();
+ }
+ return py::make_tuple(dist, matching);
+
+ }
+
+ // bug in Hera, diagonal points are ignored and don't appear in matching_a_to_b_
+ for(auto p : diag1)
+ if(p[0] == p[1]) { auto id = p.get_id(); res.matching_a_to_b_[id] = -id-1; }
+ for(auto p : diag2)
+ if(p[0] == p[1]) { auto id = p.get_id(); res.matching_a_to_b_[-id-1] = id; }
+
+ py::array_t<int> matching({{ n1 + n2, 2 }}, nullptr);
+ auto m = matching.mutable_unchecked();
+ int cur = 0;
+ for(auto x : res.matching_a_to_b_){
+ if(x.first < 0) {
+ if(x.second < 0) {
+ } else {
+ m(cur, 0) = -1;
+ m(cur, 1) = x.second;
+ ++cur;
+ }
+ } else {
+ if(x.second < 0) {
+ m(cur, 0) = x.first;
+ m(cur, 1) = -1;
+ ++cur;
+ } else {
+ m(cur, 0) = x.first;
+ m(cur, 1) = x.second;
+ ++cur;
+ }
+ }
+ }
+ // n1+n2 was too much, it only happens if everything matches to the diagonal, so we return matching[:cur,:]
+ py::array_t<int> ret({{ cur, 2 }}, {{ matching.strides()[0], matching.strides()[1] }}, matching.data(), matching);
+ return py::make_tuple(dist, ret);
}
PYBIND11_MODULE(wasserstein, m) {
@@ -45,6 +136,7 @@ PYBIND11_MODULE(wasserstein, m) {
py::arg("order") = 1,
py::arg("internal_p") = std::numeric_limits<double>::infinity(),
py::arg("delta") = .01,
+ py::arg("matching") = false,
R"pbdoc(
Compute the Wasserstein distance between two diagrams.
Points at infinity are supported.
@@ -55,8 +147,9 @@ PYBIND11_MODULE(wasserstein, m) {
order (float): Wasserstein exponent W_q
internal_p (float): Internal Minkowski norm L^p in R^2
delta (float): Relative error 1+delta
+ matching (bool): if ``True``, computes and returns the optimal matching between X and Y, encoded as a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to the j-th point in Y, with the convention that (-1) represents the diagonal. If the distance between two diagrams is +inf (which happens if the cardinalities of essential parts differ) and the matching is requested, it will be set to ``None`` (any matching is optimal).
Returns:
- float: Approximate Wasserstein distance W_q(X,Y)
+ float|Tuple[float,numpy.array|None]: Approximate Wasserstein distance W_q(X,Y), and optionally the corresponding matching
)pbdoc");
}
diff --git a/src/python/include/pybind11_diagram_utils.h b/src/python/include/pybind11_diagram_utils.h
index 2d5194f4..5cb7c48b 100644
--- a/src/python/include/pybind11_diagram_utils.h
+++ b/src/python/include/pybind11_diagram_utils.h
@@ -17,16 +17,9 @@
namespace py = pybind11;
typedef py::array_t<double> Dgm;
-// Get m[i,0] and m[i,1] as a pair
-static auto pairify(void* p, py::ssize_t h, py::ssize_t w) {
- return [=](py::ssize_t i){
- char* birth = (char*)p + i * h;
- char* death = birth + w;
- return std::make_pair(*(double*)birth, *(double*)death);
- };
-}
-
-inline auto numpy_to_range_of_pairs(py::array_t<double> dgm) {
+// build_point(double birth, double death, ssize_t index) -> Point
+template<class BuildPoint>
+inline auto numpy_to_range_of_pairs(py::array_t<double> dgm, BuildPoint build_point) {
py::buffer_info buf = dgm.request();
// shape (n,2) or (0) for empty
if((buf.ndim!=2 || buf.shape[1]!=2) && (buf.ndim!=1 || buf.shape[0]!=0))
@@ -34,6 +27,16 @@ inline auto numpy_to_range_of_pairs(py::array_t<double> dgm) {
// In the case of shape (0), avoid reading non-existing strides[1] even if we won't use it.
py::ssize_t stride1 = buf.ndim == 2 ? buf.strides[1] : 0;
auto cnt = boost::counting_range<py::ssize_t>(0, buf.shape[0]);
- return boost::adaptors::transform(cnt, pairify(buf.ptr, buf.strides[0], stride1));
+
+ char* p = static_cast<char*>(buf.ptr);
+ auto h = buf.strides[0];
+ auto w = stride1;
+ // Get m[i,0] and m[i,1] as a pair
+ auto pairify = [=](py::ssize_t i){
+ char* birth = p + i * h;
+ char* death = birth + w;
+ return build_point(*(double*)birth, *(double*)death, i);
+ };
+ return boost::adaptors::transform(cnt, pairify);
// Be careful that the returned range cannot contain references to dead temporaries.
}
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index a76b6ce7..42bf3299 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -97,6 +97,13 @@ def test_warn_infty():
assert (m is None)
+def _to_set(X):
+ return { (i, j) for i, j in X }
+
+def _same_permuted(X, Y):
+ return _to_set(X) == _to_set(Y)
+
+
def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True):
diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]])
diag2 = np.array([[2.8, 4.45], [9.5, 14.1]])
@@ -141,15 +148,16 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
if test_matching:
match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1]
- assert np.array_equal(match, [])
+ # Accept [] or np.array of shape (2, 0)
+ assert len(match) == 0
match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
- assert np.array_equal(match, [])
+ assert len(match) == 0
match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1]
- assert np.array_equal(match , [[-1, 0], [-1, 1]])
+ assert _same_permuted(match, [[-1, 0], [-1, 1]])
match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
- assert np.array_equal(match , [[0, -1], [1, -1]])
+ assert _same_permuted(match, [[0, -1], [1, -1]])
match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1]
- assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]])
+ assert _same_permuted(match, [[0, 0], [1, 1], [2, -1]])
if test_matching and test_infinity:
diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]])
@@ -158,7 +166,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
diag10 = np.array([[0,1], [-np.inf, -np.inf], [np.inf, np.inf]])
match = wasserstein_distance(diag5, diag6, matching=True, internal_p=2., order=2.)[1]
- assert np.array_equal(match, [[0, -1], [-1,0], [-1, 1], [1, 2]])
+ assert _same_permuted(match, [[0, -1], [-1,0], [-1, 1], [1, 2]])
match = wasserstein_distance(diag5, diag7, matching=True, internal_p=2., order=2.)[1]
assert (match is None)
cost, match = wasserstein_distance(diag7, emptydiag, matching=True, internal_p=2., order=2.3)
@@ -172,10 +180,10 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
assert (match is None)
cost, match = wasserstein_distance(diag9, diag10, matching=True, internal_p=1., order=1.)
assert (cost == 1)
- assert (match == [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # type 4 and 5 are match to the diag anyway.
+ assert _same_permuted(match, [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # type 4 and 5 are match to the diag anyway.
cost, match = wasserstein_distance(diag9, emptydiag, matching=True, internal_p=2., order=2.)
assert (cost == 0.)
- assert (match == [[0, -1], [1, -1]])
+ assert _same_permuted(match, [[0, -1], [1, -1]])
def hera_wrap(**extra):
@@ -196,6 +204,6 @@ def test_wasserstein_distance_pot():
def test_wasserstein_distance_hera():
- _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False)
- _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False)
+ _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=True)
+ _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=True)