summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-05-03 20:43:11 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-05-03 20:43:11 +0200
commitac7917ab2cbece048e554e32cc653c14440dbcc0 (patch)
tree07a3fc833de7b3618bc2df5bba4104bfb3b41bf9
parent07a017ca26238847e9d9ab75dcb17e52c81e6865 (diff)
Fewer copies and no GIL for hera
Now the input arrays are not copied as long as they use a float64 data type, even if they are not contiguous. That's not important here, but I wanted an example of how to do it. More importantly, no need to hold the GIL. I was too lazy to benchmark to see if that changed anything...
-rw-r--r--src/python/gudhi/hera.cc28
1 files changed, 20 insertions, 8 deletions
diff --git a/src/python/gudhi/hera.cc b/src/python/gudhi/hera.cc
index 0d562b4c..50d49c77 100644
--- a/src/python/gudhi/hera.cc
+++ b/src/python/gudhi/hera.cc
@@ -11,14 +11,24 @@
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
-#include <boost/range/iterator_range.hpp>
+#include <boost/range/counting_range.hpp>
+#include <boost/range/adaptor/transformed.hpp>
#include <wasserstein.h> // Hera
-#include <array>
+#include <utility>
namespace py = pybind11;
-typedef py::array_t<double, py::array::c_style | py::array::forcecast> Dgm;
+typedef py::array_t<double> Dgm;
+
+// Get m[i,0] and m[i,1] as a pair
+auto pairify(void* p, ssize_t h, ssize_t w) {
+ return [=](ssize_t i){
+ char* birth = (char*)p + i * h;
+ char* death = birth + w;
+ return std::make_pair(*(double*)birth, *(double*)death);
+ };
+}
double wasserstein_distance(
Dgm d1, Dgm d2,
@@ -27,16 +37,18 @@ double wasserstein_distance(
{
py::buffer_info buf1 = d1.request();
py::buffer_info buf2 = d2.request();
+
+ py::gil_scoped_release release;
+
// shape (n,2) or (0) for empty
if((buf1.ndim!=2 || buf1.shape[1]!=2) && (buf1.ndim!=1 || buf1.shape[0]!=0))
throw std::runtime_error("Diagram 1 must be an array of size n x 2");
if((buf2.ndim!=2 || buf2.shape[1]!=2) && (buf2.ndim!=1 || buf2.shape[0]!=0))
throw std::runtime_error("Diagram 2 must be an array of size n x 2");
- typedef std::array<double, 2> Point;
- auto p1 = (Point*)buf1.ptr;
- auto p2 = (Point*)buf2.ptr;
- auto diag1 = boost::make_iterator_range(p1, p1+buf1.shape[0]);
- auto diag2 = boost::make_iterator_range(p2, p2+buf2.shape[0]);
+ auto cnt1 = boost::counting_range<ssize_t>(0, buf1.shape[0]);
+ auto diag1 = boost::adaptors::transform(cnt1, pairify(buf1.ptr, buf1.strides[0], buf1.strides[1]));
+ auto cnt2 = boost::counting_range<ssize_t>(0, buf2.shape[0]);
+ auto diag2 = boost::adaptors::transform(cnt2, pairify(buf2.ptr, buf2.strides[0], buf2.strides[1]));
hera::AuctionParams<double> params;
params.wasserstein_power = wasserstein_power;