summaryrefslogtreecommitdiff
path: root/src/python/gudhi/hera.cc
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2019-12-26 10:22:47 +0100
committerMarc Glisse <marc.glisse@inria.fr>2019-12-26 10:22:47 +0100
commitc2e22942c35e894d5c1ddc429eb32687c61538c8 (patch)
treeaaf991edcadf848a78122fcaa8c94fbf6d54933e /src/python/gudhi/hera.cc
parent15f222eecf3b427c59f09ec3bec17983377d96a2 (diff)
Basic binding for wasserstein_distance
Diffstat (limited to 'src/python/gudhi/hera.cc')
-rw-r--r--src/python/gudhi/hera.cc48
1 files changed, 48 insertions, 0 deletions
diff --git a/src/python/gudhi/hera.cc b/src/python/gudhi/hera.cc
new file mode 100644
index 00000000..7cef9425
--- /dev/null
+++ b/src/python/gudhi/hera.cc
@@ -0,0 +1,48 @@
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+
+#include <boost/range/iterator_range.hpp>
+
+#include <wasserstein.h>
+
+#include <array>
+
+namespace py = pybind11;
+typedef py::array_t<double, py::array::c_style | py::array::forcecast> Dgm;
+
+namespace hera {
+template <> struct DiagramTraits<Dgm>{
+ //using Container = void;
+ using PointType = std::array<double,2>;
+ using RealType = double;
+
+ static RealType get_x(const PointType& p) { return std::get<0>(p); }
+ static RealType get_y(const PointType& p) { return std::get<1>(p); }
+};
+}
+
+double wasserstein_distance(
+ Dgm d1,
+ Dgm d2)
+{
+ py::buffer_info buf1 = d1.request();
+ py::buffer_info buf2 = d2.request();
+ if(buf1.ndim!=2 || buf1.shape[1]!=2)
+ throw std::runtime_error("Diagram 1 must be an array of size n x 2");
+ if(buf2.ndim!=2 || buf2.shape[1]!=2)
+ throw std::runtime_error("Diagram 1 must be an array of size n x 2");
+ typedef hera::DiagramTraits<Dgm>::PointType Point;
+ auto p1 = (Point*)buf1.ptr;
+ auto p2 = (Point*)buf2.ptr;
+ auto diag1 = boost::make_iterator_range(p1, p1+buf1.shape[0]);
+ auto diag2 = boost::make_iterator_range(p2, p2+buf2.shape[0]);
+
+ hera::AuctionParams<double> params;
+ return hera::wasserstein_dist(diag1, diag2, params);
+}
+
+PYBIND11_MODULE(hera, m) {
+ m.def("wasserstein_distance", &wasserstein_distance, R"pbdoc(
+ Compute the Wasserstein distance between two diagrams
+ )pbdoc");
+}