From b8701d847db37b80a58770e00b91494889df00e8 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Fri, 27 Dec 2019 00:56:08 +0100 Subject: Expose more options --- src/python/doc/wasserstein_distance_user.rst | 4 ++-- src/python/gudhi/hera.cc | 31 +++++++++++++++++++++------- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index 13f6f1af..6cd7f3a0 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -17,8 +17,8 @@ and Cluster for Persistence Diagrams via Optimal Transport". .. autofunction:: gudhi.wasserstein.wasserstein_distance This other implementation comes from `Hera -`_ and is based on `"Geometry -Helps to Compare Persistence Diagrams." +`_ (BSD-3-Clause) and is +based on `"Geometry Helps to Compare Persistence Diagrams." `_ by Michael Kerber, Dmitriy Morozov, and Arnur Nigmetov, at ALENEX 2016. diff --git a/src/python/gudhi/hera.cc b/src/python/gudhi/hera.cc index 04f5990f..898040fb 100644 --- a/src/python/gudhi/hera.cc +++ b/src/python/gudhi/hera.cc @@ -12,7 +12,6 @@ typedef py::array_t Dgm; namespace hera { template <> struct DiagramTraits{ - //using Container = void; using PointType = std::array; using RealType = double; @@ -22,15 +21,17 @@ template <> struct DiagramTraits{ } double wasserstein_distance( - Dgm d1, - Dgm d2) + Dgm d1, Dgm d2, + double wasserstein_power, double internal_p, + double delta) { py::buffer_info buf1 = d1.request(); py::buffer_info buf2 = d2.request(); - if(buf1.ndim!=2 || buf1.shape[1]!=2) - throw std::runtime_error("Diagram 1 must be an array of size n x 2"); - if(buf2.ndim!=2 || buf2.shape[1]!=2) + // shape (n,2) or (0) for empty + if((buf1.ndim!=2 || buf1.shape[1]!=2) && (buf1.ndim!=1 || buf1.shape[0]!=0)) throw std::runtime_error("Diagram 1 must be an array of size n x 2"); + if((buf2.ndim!=2 || buf2.shape[1]!=2) && (buf2.ndim!=1 || buf2.shape[0]!=0)) + throw 
std::runtime_error("Diagram 2 must be an array of size n x 2"); typedef hera::DiagramTraits<Dgm>::PointType Point; auto p1 = (Point*)buf1.ptr; auto p2 = (Point*)buf2.ptr; @@ -38,17 +39,33 @@ double wasserstein_distance( auto diag2 = boost::make_iterator_range(p2, p2+buf2.shape[0]); hera::AuctionParams<double> params; + params.wasserstein_power = wasserstein_power; + // hera encodes infinity as -1... + if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>(); + params.internal_p = internal_p; + params.delta = delta; + // The extra parameters are purposely not exposed for now. return hera::wasserstein_dist(diag1, diag2, params); } PYBIND11_MODULE(hera, m) { m.def("wasserstein_distance", &wasserstein_distance, py::arg("X"), py::arg("Y"), + // Should we name those q, p and d instead? + py::arg("wasserstein_power") = 1, + py::arg("internal_p") = std::numeric_limits<double>::infinity(), + py::arg("delta") = .01, R"pbdoc( - Compute the Wasserstein distance between two diagrams + Compute the Wasserstein distance between two diagrams. Points at infinity are supported. Parameters: X (n x 2 numpy array): First diagram Y (n x 2 numpy array): Second diagram + wasserstein_power (float): Wasserstein degree W_q + internal_p (float): Internal Minkowski norm L^p in R^2 + delta (float): Relative error 1+delta + + Returns: + float: Approximate Wasserstein distance W_q(X,Y) )pbdoc"); } -- cgit v1.2.3