summaryrefslogtreecommitdiff
path: root/src/python/gudhi/hera.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/hera.cc')
-rw-r--r--src/python/gudhi/hera.cc27
1 files changed, 6 insertions, 21 deletions
diff --git a/src/python/gudhi/hera.cc b/src/python/gudhi/hera.cc
index 0d562b4c..ea80a9a8 100644
--- a/src/python/gudhi/hera.cc
+++ b/src/python/gudhi/hera.cc
@@ -8,35 +8,20 @@
* - YYYY/MM Author: Description of the modification
*/
-#include <pybind11/pybind11.h>
-#include <pybind11/numpy.h>
-
-#include <boost/range/iterator_range.hpp>
-
#include <wasserstein.h> // Hera
-#include <array>
-
-namespace py = pybind11;
-typedef py::array_t<double, py::array::c_style | py::array::forcecast> Dgm;
+#include <pybind11_diagram_utils.h>
double wasserstein_distance(
Dgm d1, Dgm d2,
double wasserstein_power, double internal_p,
double delta)
{
- py::buffer_info buf1 = d1.request();
- py::buffer_info buf2 = d2.request();
- // shape (n,2) or (0) for empty
- if((buf1.ndim!=2 || buf1.shape[1]!=2) && (buf1.ndim!=1 || buf1.shape[0]!=0))
- throw std::runtime_error("Diagram 1 must be an array of size n x 2");
- if((buf2.ndim!=2 || buf2.shape[1]!=2) && (buf2.ndim!=1 || buf2.shape[0]!=0))
- throw std::runtime_error("Diagram 2 must be an array of size n x 2");
- typedef std::array<double, 2> Point;
- auto p1 = (Point*)buf1.ptr;
- auto p2 = (Point*)buf2.ptr;
- auto diag1 = boost::make_iterator_range(p1, p1+buf1.shape[0]);
- auto diag2 = boost::make_iterator_range(p2, p2+buf2.shape[0]);
+ // I *think* the call to request() has to be before releasing the GIL.
+ auto diag1 = numpy_to_range_of_pairs(d1);
+ auto diag2 = numpy_to_range_of_pairs(d2);
+
+ py::gil_scoped_release release;
hera::AuctionParams<double> params;
params.wasserstein_power = wasserstein_power;