diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-05-09 10:37:00 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-05-09 10:37:00 +0200 |
commit | 894462a364dd5d4bf4a5250c0c3c075c561fb174 (patch) | |
tree | 02f5faedfe1e80335f3414fd09c0c18a45854129 /src/python/gudhi/hera.cc | |
parent | 5ad8f41550d94988214fbf128a179d918635c3cf (diff) | |
parent | b91fa852b141a5f8b4a4915849b35fbdb6993772 (diff) |
Merge remote-tracking branch 'origin/master' into nogil1
Diffstat (limited to 'src/python/gudhi/hera.cc')
-rw-r--r-- | src/python/gudhi/hera.cc | 27 |
1 files changed, 6 insertions, 21 deletions
diff --git a/src/python/gudhi/hera.cc b/src/python/gudhi/hera.cc index 0d562b4c..ea80a9a8 100644 --- a/src/python/gudhi/hera.cc +++ b/src/python/gudhi/hera.cc @@ -8,35 +8,20 @@ * - YYYY/MM Author: Description of the modification */ -#include <pybind11/pybind11.h> -#include <pybind11/numpy.h> - -#include <boost/range/iterator_range.hpp> - #include <wasserstein.h> // Hera -#include <array> - -namespace py = pybind11; -typedef py::array_t<double, py::array::c_style | py::array::forcecast> Dgm; +#include <pybind11_diagram_utils.h> double wasserstein_distance( Dgm d1, Dgm d2, double wasserstein_power, double internal_p, double delta) { - py::buffer_info buf1 = d1.request(); - py::buffer_info buf2 = d2.request(); - // shape (n,2) or (0) for empty - if((buf1.ndim!=2 || buf1.shape[1]!=2) && (buf1.ndim!=1 || buf1.shape[0]!=0)) - throw std::runtime_error("Diagram 1 must be an array of size n x 2"); - if((buf2.ndim!=2 || buf2.shape[1]!=2) && (buf2.ndim!=1 || buf2.shape[0]!=0)) - throw std::runtime_error("Diagram 2 must be an array of size n x 2"); - typedef std::array<double, 2> Point; - auto p1 = (Point*)buf1.ptr; - auto p2 = (Point*)buf2.ptr; - auto diag1 = boost::make_iterator_range(p1, p1+buf1.shape[0]); - auto diag2 = boost::make_iterator_range(p2, p2+buf2.shape[0]); + // I *think* the call to request() has to be before releasing the GIL. + auto diag1 = numpy_to_range_of_pairs(d1); + auto diag2 = numpy_to_range_of_pairs(d2); + + py::gil_scoped_release release; hera::AuctionParams<double> params; params.wasserstein_power = wasserstein_power; |