diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-05-07 15:18:55 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-07 15:18:55 +0200 |
commit | 89b34f069e632a8fea0642556a4010de821ed6c9 (patch) | |
tree | 12d906ab4a4914aa6ecd6cc6ba9a2be397f92177 /src/python | |
parent | cb969cf6b98c4cef0a08eee698c81b0510206b84 (diff) | |
parent | 5c5e2c3075235079fda94fc6a159cc5275f85a0c (diff) |
Merge pull request #302 from mglisse/hera-gil
Fewer copies and no GIL for hera
Diffstat (limited to 'src/python')
-rw-r--r-- | src/python/gudhi/hera.cc | 45 |
1 file changed, 30 insertions, 15 deletions
diff --git a/src/python/gudhi/hera.cc b/src/python/gudhi/hera.cc index 0d562b4c..5aec1806 100644 --- a/src/python/gudhi/hera.cc +++ b/src/python/gudhi/hera.cc @@ -11,32 +11,47 @@ #include <pybind11/pybind11.h> #include <pybind11/numpy.h> -#include <boost/range/iterator_range.hpp> +#include <boost/range/counting_range.hpp> +#include <boost/range/adaptor/transformed.hpp> #include <wasserstein.h> // Hera -#include <array> +#include <utility> namespace py = pybind11; -typedef py::array_t<double, py::array::c_style | py::array::forcecast> Dgm; +typedef py::array_t<double> Dgm; + +// Get m[i,0] and m[i,1] as a pair +static auto pairify(void* p, ssize_t h, ssize_t w) { + return [=](ssize_t i){ + char* birth = (char*)p + i * h; + char* death = birth + w; + return std::make_pair(*(double*)birth, *(double*)death); + }; +} + +inline auto numpy_to_range_of_pairs(py::array_t<double> dgm) { + py::buffer_info buf = dgm.request(); + // shape (n,2) or (0) for empty + if((buf.ndim!=2 || buf.shape[1]!=2) && (buf.ndim!=1 || buf.shape[0]!=0)) + throw std::runtime_error("Diagram must be an array of size n x 2"); + // In the case of shape (0), avoid reading non-existing strides[1] even if we won't use it. + ssize_t stride1 = buf.ndim == 2 ? buf.strides[1] : 0; + auto cnt = boost::counting_range<ssize_t>(0, buf.shape[0]); + return boost::adaptors::transform(cnt, pairify(buf.ptr, buf.strides[0], stride1)); + // Be careful that the returned range cannot contain references to dead temporaries. 
+} double wasserstein_distance( Dgm d1, Dgm d2, double wasserstein_power, double internal_p, double delta) { - py::buffer_info buf1 = d1.request(); - py::buffer_info buf2 = d2.request(); - // shape (n,2) or (0) for empty - if((buf1.ndim!=2 || buf1.shape[1]!=2) && (buf1.ndim!=1 || buf1.shape[0]!=0)) - throw std::runtime_error("Diagram 1 must be an array of size n x 2"); - if((buf2.ndim!=2 || buf2.shape[1]!=2) && (buf2.ndim!=1 || buf2.shape[0]!=0)) - throw std::runtime_error("Diagram 2 must be an array of size n x 2"); - typedef std::array<double, 2> Point; - auto p1 = (Point*)buf1.ptr; - auto p2 = (Point*)buf2.ptr; - auto diag1 = boost::make_iterator_range(p1, p1+buf1.shape[0]); - auto diag2 = boost::make_iterator_range(p2, p2+buf2.shape[0]); + // I *think* the call to request() has to be before releasing the GIL. + auto diag1 = numpy_to_range_of_pairs(d1); + auto diag2 = numpy_to_range_of_pairs(d2); + + py::gil_scoped_release release; hera::AuctionParams<double> params; params.wasserstein_power = wasserstein_power; |