diff options
author | Hind <hind.montassif@gmail.com> | 2021-04-23 11:27:59 +0200 |
---|---|---|
committer | Hind <hind.montassif@gmail.com> | 2021-04-23 11:27:59 +0200 |
commit | db7ce3487e526741c0408b00c2cffda0048b0026 (patch) | |
tree | 886a7295a170823c4ae11c8f117a431fd72de39f /src | |
parent | 45917ecf17acacfede909994d7b3a78fc18355da (diff) |
Make adjustments according to the received reviews
Diffstat (limited to 'src')
-rw-r--r-- | src/python/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/python/gudhi/random_point_generators.cc | 45 |
2 files changed, 21 insertions, 26 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 8baf0f02..87f10a1a 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -43,10 +43,10 @@ endfunction( add_gudhi_debug_info ) if(PYTHONINTERP_FOUND) if(PYBIND11_FOUND) add_gudhi_debug_info("Pybind11 version ${PYBIND11_VERSION}") - set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'random_point_generators', ") set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'bottleneck', ") set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'hera', ") set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'clustering', ") + set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'random_point_generators', ") endif() if(CYTHON_FOUND) set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'off_reader', ") diff --git a/src/python/gudhi/random_point_generators.cc b/src/python/gudhi/random_point_generators.cc index 39b09a6d..4306ba98 100644 --- a/src/python/gudhi/random_point_generators.cc +++ b/src/python/gudhi/random_point_generators.cc @@ -21,30 +21,25 @@ namespace py = pybind11; typedef CGAL::Epick_d< CGAL::Dynamic_dimension_tag > Kern; template <typename Kernel> -py::array_t<double> generate_points_on_sphere(py::object num_points, py::object dim, py::object radius) { - int npoints = num_points.cast<int>(); - int d = dim.cast<int>(); - double rad = radius.cast<double>(); - - py::gil_scoped_release release; - - auto points_generated = Gudhi::generate_points_on_sphere_d<Kernel>(npoints, d, rad); - - py::gil_scoped_acquire acquire; - - py::array_t<double> points({npoints, d}); +py::array_t<double> generate_points_on_sphere(size_t num_points, int dim, double radius) { + + py::array_t<double> points({(int)num_points, dim}); py::buffer_info buf = points.request(); - double *ptr = static_cast<double *>(buf.ptr); - assert(npoints == buf.shape[0]); - assert(d == buf.shape[1]); + assert(num_points == buf.shape[0]); + assert(dim == buf.shape[1]); - - for (size_t i = 0; i < (size_t)npoints; i++) - for (size_t j = 0; j < (size_t)d; j++) - ptr[i*d+j] = points_generated.at(i).at(j); + std::vector<typename Kernel::Point_d> points_generated; + { + py::gil_scoped_release release; + points_generated = Gudhi::generate_points_on_sphere_d<Kernel>(num_points, dim, radius); + + for (size_t i = 0; i < num_points; i++) + for (size_t j = 0; j < (size_t)dim; j++) + ptr[i*dim+j] = points_generated[i][j]; + } return points; } @@ -52,17 +47,17 @@ py::array_t<double> generate_points_on_sphere(py::object num_points, py::object PYBIND11_MODULE(random_point_generators, m) { m.attr("__license__") = "LGPL v3"; m.def("generate_points_on_sphere_d", &generate_points_on_sphere<Kern>, - py::arg("num_points"), py::arg("dim"), py::arg("radius"), + py::arg("num_points"), py::arg("dim"), py::arg("radius") = 1, R"pbdoc( - Generate points on a sphere + Generate random i.i.d. points uniformly on a (d-1)-sphere in Rd :param num_points: The number of points to be generated. - :type num_points: integer - :param dim: The sphere dimension. + :type num_points: unsigned integer + :param dim: The dimension. :type dim: integer - :param radius: The sphere radius. + :param radius: The radius. :type radius: float - :rtype: numpy array of points + :rtype: numpy array of float :returns: the generated points on a sphere. )pbdoc"); } |