summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-05-06 14:13:14 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-05-06 14:13:14 +0200
commit5c5e2c3075235079fda94fc6a159cc5275f85a0c (patch)
tree347c977390da36264f27f81f0751f4be76d51279
parentdac92c5ae9da6aa21fdcd261737e08d6898dbbdc (diff)
Refactor the numpy -> C++ range conversion
If we want to reuse it for bottleneck...
-rw-r--r--src/python/gudhi/hera.cc31
1 files changed, 16 insertions, 15 deletions
diff --git a/src/python/gudhi/hera.cc b/src/python/gudhi/hera.cc
index 63bbb075..5aec1806 100644
--- a/src/python/gudhi/hera.cc
+++ b/src/python/gudhi/hera.cc
@@ -22,7 +22,7 @@ namespace py = pybind11;
typedef py::array_t<double> Dgm;
// Get m[i,0] and m[i,1] as a pair
-auto pairify(void* p, ssize_t h, ssize_t w) {
+static auto pairify(void* p, ssize_t h, ssize_t w) {
return [=](ssize_t i){
char* birth = (char*)p + i * h;
char* death = birth + w;
@@ -30,28 +30,29 @@ auto pairify(void* p, ssize_t h, ssize_t w) {
};
}
+inline auto numpy_to_range_of_pairs(py::array_t<double> dgm) {
+ py::buffer_info buf = dgm.request();
+ // shape (n,2) or (0) for empty
+ if((buf.ndim!=2 || buf.shape[1]!=2) && (buf.ndim!=1 || buf.shape[0]!=0))
+ throw std::runtime_error("Diagram must be an array of size n x 2");
+ // In the case of shape (0), avoid reading non-existing strides[1] even if we won't use it.
+ ssize_t stride1 = buf.ndim == 2 ? buf.strides[1] : 0;
+ auto cnt = boost::counting_range<ssize_t>(0, buf.shape[0]);
+ return boost::adaptors::transform(cnt, pairify(buf.ptr, buf.strides[0], stride1));
+ // Be careful that the returned range cannot contain references to dead temporaries.
+}
+
double wasserstein_distance(
Dgm d1, Dgm d2,
double wasserstein_power, double internal_p,
double delta)
{
- py::buffer_info buf1 = d1.request();
- py::buffer_info buf2 = d2.request();
+ // I *think* the call to request() has to be before releasing the GIL.
+ auto diag1 = numpy_to_range_of_pairs(d1);
+ auto diag2 = numpy_to_range_of_pairs(d2);
py::gil_scoped_release release;
- // shape (n,2) or (0) for empty
- if((buf1.ndim!=2 || buf1.shape[1]!=2) && (buf1.ndim!=1 || buf1.shape[0]!=0))
- throw std::runtime_error("Diagram 1 must be an array of size n x 2");
- if((buf2.ndim!=2 || buf2.shape[1]!=2) && (buf2.ndim!=1 || buf2.shape[0]!=0))
- throw std::runtime_error("Diagram 2 must be an array of size n x 2");
- ssize_t stride11 = buf1.ndim == 2 ? buf1.strides[1] : 0;
- ssize_t stride21 = buf2.ndim == 2 ? buf2.strides[1] : 0;
- auto cnt1 = boost::counting_range<ssize_t>(0, buf1.shape[0]);
- auto diag1 = boost::adaptors::transform(cnt1, pairify(buf1.ptr, buf1.strides[0], stride11));
- auto cnt2 = boost::counting_range<ssize_t>(0, buf2.shape[0]);
- auto diag2 = boost::adaptors::transform(cnt2, pairify(buf2.ptr, buf2.strides[0], stride21));
-
hera::AuctionParams<double> params;
params.wasserstein_power = wasserstein_power;
// hera encodes infinity as -1...