summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-05-07 15:18:55 +0200
committerGitHub <noreply@github.com>2020-05-07 15:18:55 +0200
commit89b34f069e632a8fea0642556a4010de821ed6c9 (patch)
tree12d906ab4a4914aa6ecd6cc6ba9a2be397f92177
parentcb969cf6b98c4cef0a08eee698c81b0510206b84 (diff)
parent5c5e2c3075235079fda94fc6a159cc5275f85a0c (diff)
Merge pull request #302 from mglisse/hera-gil
Fewer copies and no GIL for hera
-rw-r--r--src/python/gudhi/hera.cc45
1 files changed, 30 insertions, 15 deletions
diff --git a/src/python/gudhi/hera.cc b/src/python/gudhi/hera.cc
index 0d562b4c..5aec1806 100644
--- a/src/python/gudhi/hera.cc
+++ b/src/python/gudhi/hera.cc
@@ -11,32 +11,47 @@
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
-#include <boost/range/iterator_range.hpp>
+#include <boost/range/counting_range.hpp>
+#include <boost/range/adaptor/transformed.hpp>
#include <wasserstein.h> // Hera
-#include <array>
+#include <utility>
namespace py = pybind11;
-typedef py::array_t<double, py::array::c_style | py::array::forcecast> Dgm;
+typedef py::array_t<double> Dgm;
+
+// Get m[i,0] and m[i,1] as a pair: builds a functor that maps a row index i to the two doubles of row i, read from raw memory.
+static auto pairify(void* p, ssize_t h, ssize_t w) { // p: base address of the data; h: byte step from one row to the next; w: byte step from column 0 to column 1
+ return [=](ssize_t i){ // p, h and w are captured by copy; the closure only holds the raw pointer, it does not own the memory it reads
+ char* birth = (char*)p + i * h; // arithmetic is done on char*, so h and w are byte counts, not counts of doubles
+ char* death = birth + w; // address of the second value of row i
+ return std::make_pair(*(double*)birth, *(double*)death); // both values are copied out, so the result does not refer back into the buffer
+ };
+}
+
+inline auto numpy_to_range_of_pairs(py::array_t<double> dgm) { // Checks the shape of dgm and returns a Boost transformed range over its row indices yielding (m[i,0], m[i,1]) pairs; throws std::runtime_error on any other shape
+ py::buffer_info buf = dgm.request(); // buffer description (ptr, ndim, shape, strides) used for everything below
+ // shape (n,2) or (0) for empty
+ if((buf.ndim!=2 || buf.shape[1]!=2) && (buf.ndim!=1 || buf.shape[0]!=0)) // rejects everything except ndim==2 with 2 columns, or ndim==1 with 0 elements
+ throw std::runtime_error("Diagram must be an array of size n x 2");
+ // In the case of shape (0), avoid reading non-existing strides[1] even if we won't use it.
+ ssize_t stride1 = buf.ndim == 2 ? buf.strides[1] : 0; // 0 is a placeholder for the empty 1-d case, where no row is ever dereferenced
+ auto cnt = boost::counting_range<ssize_t>(0, buf.shape[0]); // row indices 0 .. shape[0]-1
+ return boost::adaptors::transform(cnt, pairify(buf.ptr, buf.strides[0], stride1)); // strides[0] is pairify's row step, stride1 its column step; NOTE(review): the range keeps the raw buf.ptr, so presumably the caller's array must outlive it - confirm
+ // Be careful that the returned range cannot contain references to dead temporaries.
+}
double wasserstein_distance(
Dgm d1, Dgm d2,
double wasserstein_power, double internal_p,
double delta)
{
- py::buffer_info buf1 = d1.request();
- py::buffer_info buf2 = d2.request();
- // shape (n,2) or (0) for empty
- if((buf1.ndim!=2 || buf1.shape[1]!=2) && (buf1.ndim!=1 || buf1.shape[0]!=0))
- throw std::runtime_error("Diagram 1 must be an array of size n x 2");
- if((buf2.ndim!=2 || buf2.shape[1]!=2) && (buf2.ndim!=1 || buf2.shape[0]!=0))
- throw std::runtime_error("Diagram 2 must be an array of size n x 2");
- typedef std::array<double, 2> Point;
- auto p1 = (Point*)buf1.ptr;
- auto p2 = (Point*)buf2.ptr;
- auto diag1 = boost::make_iterator_range(p1, p1+buf1.shape[0]);
- auto diag2 = boost::make_iterator_range(p2, p2+buf2.shape[0]);
+ // I *think* the call to request() has to be before releasing the GIL.
+ auto diag1 = numpy_to_range_of_pairs(d1);
+ auto diag2 = numpy_to_range_of_pairs(d2);
+
+ py::gil_scoped_release release;
hera::AuctionParams<double> params;
params.wasserstein_power = wasserstein_power;