diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2019-12-27 00:56:08 +0100 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2019-12-27 00:56:08 +0100 |
commit | b8701d847db37b80a58770e00b91494889df00e8 (patch) | |
tree | 98333269d139fe648b2ae2a14db3fc03af995479 | |
parent | 7568b34c56e6a6102507df1be0029a0259f2afa7 (diff) |
Expose more options
-rw-r--r-- | src/python/doc/wasserstein_distance_user.rst | 4 | ||||
-rw-r--r-- | src/python/gudhi/hera.cc | 31 |
2 files changed, 26 insertions, 9 deletions
diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index 13f6f1af..6cd7f3a0 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -17,8 +17,8 @@ and Cluster for Persistence Diagrams via Optimal Transport". .. autofunction:: gudhi.wasserstein.wasserstein_distance This other implementation comes from `Hera -<https://bitbucket.org/grey_narn/hera/src/master/>`_ and is based on `"Geometry -Helps to Compare Persistence Diagrams." +<https://bitbucket.org/grey_narn/hera/src/master/>`_ (BSD-3-Clause) and is +based on `"Geometry Helps to Compare Persistence Diagrams." <http://dx.doi.org/10.1137/1.9781611974317.9>`_ by Michael Kerber, Dmitriy Morozov, and Arnur Nigmetov, at ALENEX 2016. diff --git a/src/python/gudhi/hera.cc b/src/python/gudhi/hera.cc index 04f5990f..898040fb 100644 --- a/src/python/gudhi/hera.cc +++ b/src/python/gudhi/hera.cc @@ -12,7 +12,6 @@ typedef py::array_t<double, py::array::c_style | py::array::forcecast> Dgm; namespace hera { template <> struct DiagramTraits<Dgm>{ - //using Container = void; using PointType = std::array<double,2>; using RealType = double; @@ -22,15 +21,17 @@ template <> struct DiagramTraits<Dgm>{ } double wasserstein_distance( - Dgm d1, - Dgm d2) + Dgm d1, Dgm d2, + double wasserstein_power, double internal_p, + double delta) { py::buffer_info buf1 = d1.request(); py::buffer_info buf2 = d2.request(); - if(buf1.ndim!=2 || buf1.shape[1]!=2) - throw std::runtime_error("Diagram 1 must be an array of size n x 2"); - if(buf2.ndim!=2 || buf2.shape[1]!=2) + // shape (n,2) or (0) for empty + if((buf1.ndim!=2 || buf1.shape[1]!=2) && (buf1.ndim!=1 || buf1.shape[0]!=0)) throw std::runtime_error("Diagram 1 must be an array of size n x 2"); + if((buf2.ndim!=2 || buf2.shape[1]!=2) && (buf2.ndim!=1 || buf2.shape[0]!=0)) + throw std::runtime_error("Diagram 2 must be an array of size n x 2"); typedef hera::DiagramTraits<Dgm>::PointType Point; auto p1 = (Point*)buf1.ptr; auto p2 = (Point*)buf2.ptr; @@ -38,17 +39,33 @@ double wasserstein_distance( auto diag2 = boost::make_iterator_range(p2, p2+buf2.shape[0]); hera::AuctionParams<double> params; + params.wasserstein_power = wasserstein_power; + // hera encodes infinity as -1... + if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>(); + params.internal_p = internal_p; + params.delta = delta; + // The extra parameters are purposedly not exposed for now. return hera::wasserstein_dist(diag1, diag2, params); } PYBIND11_MODULE(hera, m) { m.def("wasserstein_distance", &wasserstein_distance, py::arg("X"), py::arg("Y"), + // Should we name those q, p and d instead? + py::arg("wasserstein_power") = 1, + py::arg("internal_p") = std::numeric_limits<double>::infinity(), + py::arg("delta") = .01, R"pbdoc( - Compute the Wasserstein distance between two diagrams + Compute the Wasserstein distance between two diagrams. Points at infinity are supported. Parameters: X (n x 2 numpy array): First diagram Y (n x 2 numpy array): Second diagram + wasserstein_power (float): Wasserstein degree W_q + internal_p (float): Internal Minkowski norm L^p in R^2 + delta (float): Relative error 1+delta + + Returns: + float: Approximate Wasserstein distance W_q(X,Y) )pbdoc"); } |