summaryrefslogtreecommitdiff
path: root/src/python/gudhi/hera/wasserstein.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/hera/wasserstein.cc')
-rw-r--r--src/python/gudhi/hera/wasserstein.cc137
1 files changed, 118 insertions, 19 deletions
diff --git a/src/python/gudhi/hera/wasserstein.cc b/src/python/gudhi/hera/wasserstein.cc
index 1a21f02f..41e84f7b 100644
--- a/src/python/gudhi/hera/wasserstein.cc
+++ b/src/python/gudhi/hera/wasserstein.cc
@@ -8,29 +8,126 @@
* - YYYY/MM Author: Description of the modification
*/
-#include <wasserstein.h> // Hera
-
#include <pybind11_diagram_utils.h>
-double wasserstein_distance(
+#ifdef _MSC_VER
+// https://github.com/grey-narn/hera/issues/3
+// ssize_t is a non-standard type (well, posix)
+using py::ssize_t;
+#endif
+
+#include <hera/wasserstein.h>
+#include <gudhi/Debug_utils.h>
+
+// Unlike bottleneck, for wasserstein, we need to add the index ourselves (if we want the matching)
+static auto make_hera_point(double x, double y, py::ssize_t i) { return hera::DiagramPoint<double>(x, y, i); };
+
+py::object wasserstein_distance(
Dgm d1, Dgm d2,
double wasserstein_power, double internal_p,
- double delta)
+ double delta, bool return_matching)
{
- // I *think* the call to request() has to be before releasing the GIL.
- auto diag1 = numpy_to_range_of_pairs(d1);
- auto diag2 = numpy_to_range_of_pairs(d2);
-
- py::gil_scoped_release release;
-
- hera::AuctionParams<double> params;
- params.wasserstein_power = wasserstein_power;
- // hera encodes infinity as -1...
- if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>();
- params.internal_p = internal_p;
- params.delta = delta;
- // The extra parameters are purposedly not exposed for now.
- return hera::wasserstein_dist(diag1, diag2, params);
+ // I *think* the call to request() in numpy_to_range_of_pairs has to be before releasing the GIL.
+ auto diag1 = numpy_to_range_of_pairs(d1, make_hera_point);
+ auto diag2 = numpy_to_range_of_pairs(d2, make_hera_point);
+ int n1 = boost::size(diag1);
+ int n2 = boost::size(diag2);
+ hera::AuctionResult<double> res;
+ double dist;
+
+ { // No Python allowed in this section
+ py::gil_scoped_release release;
+
+ hera::AuctionParams<double> params;
+ params.wasserstein_power = wasserstein_power;
+ // hera encodes infinity as -1...
+ if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>();
+ params.internal_p = internal_p;
+ params.delta = delta;
+ if(return_matching) {
+ params.return_matching = true;
+ params.match_inf_points = true;
+ }
+ // The extra parameters are purposely not exposed for now.
+ res = hera::wasserstein_cost_detailed(diag1, diag2, params);
+ dist = std::pow(res.cost, 1./params.wasserstein_power);
+ }
+
+ if(!return_matching)
+ return py::cast(dist);
+
+ if(dist == std::numeric_limits<double>::infinity())
+ return py::make_tuple(dist, py::none());
+
+ // bug in Hera, matching_a_to_b_ is empty if one diagram is empty or both diagrams contain the same points
+ if(res.matching_a_to_b_.size() == 0) {
+ if(n1 == 0) { // diag1 is empty
+ py::array_t<int> matching({{ n2, 2 }}, nullptr);
+ auto m = matching.mutable_unchecked();
+ for(int j=0; j<n2; ++j){
+ m(j, 0) = -1;
+ m(j, 1) = j;
+ }
+ return py::make_tuple(dist, matching);
+ }
+ if(n2 == 0) { // diag2 is empty
+ py::array_t<int> matching({{ n1, 2 }}, nullptr);
+ auto m = matching.mutable_unchecked();
+ for(int i=0; i<n1; ++i){
+ m(i, 0) = i;
+ m(i, 1) = -1;
+ }
+ return py::make_tuple(dist, matching);
+ }
+ // The only remaining case should be that the 2 diagrams are identical, but possibly shuffled
+ GUDHI_CHECK(n1==n2, "unexpected bug in Hera?");
+ std::vector v1(boost::begin(diag1), boost::end(diag1));
+ std::vector v2(boost::begin(diag2), boost::end(diag2));
+ std::sort(v1.begin(), v1.end());
+ std::sort(v2.begin(), v2.end());
+ py::array_t<int> matching({{ n1, 2 }}, nullptr);
+ auto m = matching.mutable_unchecked();
+ for(int i=0; i<n1; ++i){
+ GUDHI_CHECK(v1[i][0]==v2[i][0] && v1[i][1]==v2[i][1], "unexpected bug in Hera?");
+ m(i, 0) = v1[i].get_id();
+ m(i, 1) = v2[i].get_id();
+ }
+ return py::make_tuple(dist, matching);
+
+ }
+
+ // bug in Hera, diagonal points are ignored and don't appear in matching_a_to_b_
+ for(auto p : diag1)
+ if(p[0] == p[1]) { auto id = p.get_id(); res.matching_a_to_b_[id] = -id-1; }
+ for(auto p : diag2)
+ if(p[0] == p[1]) { auto id = p.get_id(); res.matching_a_to_b_[-id-1] = id; }
+
+ py::array_t<int> matching({{ n1 + n2, 2 }}, nullptr);
+ auto m = matching.mutable_unchecked();
+ int cur = 0;
+ for(auto x : res.matching_a_to_b_){
+ if(x.first < 0) {
+ if(x.second < 0) {
+ } else {
+ m(cur, 0) = -1;
+ m(cur, 1) = x.second;
+ ++cur;
+ }
+ } else {
+ if(x.second < 0) {
+ m(cur, 0) = x.first;
+ m(cur, 1) = -1;
+ ++cur;
+ } else {
+ m(cur, 0) = x.first;
+ m(cur, 1) = x.second;
+ ++cur;
+ }
+ }
+ }
+ // n1+n2 was too much, it only happens if everything matches to the diagonal, so we return matching[:cur,:]
+ py::array_t<int> ret({{ cur, 2 }}, {{ matching.strides()[0], matching.strides()[1] }}, matching.data(), matching);
+ return py::make_tuple(dist, ret);
}
PYBIND11_MODULE(wasserstein, m) {
@@ -39,6 +136,7 @@ PYBIND11_MODULE(wasserstein, m) {
py::arg("order") = 1,
py::arg("internal_p") = std::numeric_limits<double>::infinity(),
py::arg("delta") = .01,
+ py::arg("matching") = false,
R"pbdoc(
Compute the Wasserstein distance between two diagrams.
Points at infinity are supported.
@@ -49,8 +147,9 @@ PYBIND11_MODULE(wasserstein, m) {
order (float): Wasserstein exponent W_q
internal_p (float): Internal Minkowski norm L^p in R^2
delta (float): Relative error 1+delta
+ matching (bool): if ``True``, computes and returns the optimal matching between X and Y, encoded as a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to the j-th point in Y, with the convention that (-1) represents the diagonal. If the distance between two diagrams is +inf (which happens if the cardinalities of essential parts differ) and the matching is requested, it will be set to ``None`` (any matching is optimal).
Returns:
- float: Approximate Wasserstein distance W_q(X,Y)
+ float|Tuple[float,numpy.array|None]: Approximate Wasserstein distance W_q(X,Y), and optionally the corresponding matching
)pbdoc");
}