summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2019-12-27 00:56:08 +0100
committerMarc Glisse <marc.glisse@inria.fr>2019-12-27 00:56:08 +0100
commitb8701d847db37b80a58770e00b91494889df00e8 (patch)
tree98333269d139fe648b2ae2a14db3fc03af995479
parent7568b34c56e6a6102507df1be0029a0259f2afa7 (diff)
Expose more options
-rw-r--r--src/python/doc/wasserstein_distance_user.rst4
-rw-r--r--src/python/gudhi/hera.cc31
2 files changed, 26 insertions, 9 deletions
diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst
index 13f6f1af..6cd7f3a0 100644
--- a/src/python/doc/wasserstein_distance_user.rst
+++ b/src/python/doc/wasserstein_distance_user.rst
@@ -17,8 +17,8 @@ and Cluster for Persistence Diagrams via Optimal Transport".
.. autofunction:: gudhi.wasserstein.wasserstein_distance
This other implementation comes from `Hera
-<https://bitbucket.org/grey_narn/hera/src/master/>`_ and is based on `"Geometry
-Helps to Compare Persistence Diagrams."
+<https://bitbucket.org/grey_narn/hera/src/master/>`_ (BSD-3-Clause) and is
+based on `"Geometry Helps to Compare Persistence Diagrams."
<http://dx.doi.org/10.1137/1.9781611974317.9>`_ by Michael Kerber, Dmitriy
Morozov, and Arnur Nigmetov, at ALENEX 2016.
diff --git a/src/python/gudhi/hera.cc b/src/python/gudhi/hera.cc
index 04f5990f..898040fb 100644
--- a/src/python/gudhi/hera.cc
+++ b/src/python/gudhi/hera.cc
@@ -12,7 +12,6 @@ typedef py::array_t<double, py::array::c_style | py::array::forcecast> Dgm;
namespace hera {
template <> struct DiagramTraits<Dgm>{
- //using Container = void;
using PointType = std::array<double,2>;
using RealType = double;
@@ -22,15 +21,17 @@ template <> struct DiagramTraits<Dgm>{
}
double wasserstein_distance(
- Dgm d1,
- Dgm d2)
+ Dgm d1, Dgm d2,
+ double wasserstein_power, double internal_p,
+ double delta)
{
py::buffer_info buf1 = d1.request();
py::buffer_info buf2 = d2.request();
- if(buf1.ndim!=2 || buf1.shape[1]!=2)
- throw std::runtime_error("Diagram 1 must be an array of size n x 2");
- if(buf2.ndim!=2 || buf2.shape[1]!=2)
+ // shape (n,2) or (0) for empty
+ if((buf1.ndim!=2 || buf1.shape[1]!=2) && (buf1.ndim!=1 || buf1.shape[0]!=0))
throw std::runtime_error("Diagram 1 must be an array of size n x 2");
+ if((buf2.ndim!=2 || buf2.shape[1]!=2) && (buf2.ndim!=1 || buf2.shape[0]!=0))
+ throw std::runtime_error("Diagram 2 must be an array of size n x 2");
typedef hera::DiagramTraits<Dgm>::PointType Point;
auto p1 = (Point*)buf1.ptr;
auto p2 = (Point*)buf2.ptr;
@@ -38,17 +39,33 @@ double wasserstein_distance(
auto diag2 = boost::make_iterator_range(p2, p2+buf2.shape[0]);
hera::AuctionParams<double> params;
+ params.wasserstein_power = wasserstein_power;
+ // hera encodes infinity as -1...
+ if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>();
+ params.internal_p = internal_p;
+ params.delta = delta;
+ // The extra parameters are purposedly not exposed for now.
return hera::wasserstein_dist(diag1, diag2, params);
}
PYBIND11_MODULE(hera, m) {
m.def("wasserstein_distance", &wasserstein_distance,
py::arg("X"), py::arg("Y"),
+ // Should we name those q, p and d instead?
+ py::arg("wasserstein_power") = 1,
+ py::arg("internal_p") = std::numeric_limits<double>::infinity(),
+ py::arg("delta") = .01,
R"pbdoc(
- Compute the Wasserstein distance between two diagrams
+ Compute the Wasserstein distance between two diagrams. Points at infinity are supported.
Parameters:
X (n x 2 numpy array): First diagram
Y (n x 2 numpy array): Second diagram
+ wasserstein_power (float): Wasserstein degree W_q
+ internal_p (float): Internal Minkowski norm L^p in R^2
+ delta (float): Relative error 1+delta
+
+ Returns:
+ float: Approximate Wasserstein distance W_q(X,Y)
)pbdoc");
}