summaryrefslogtreecommitdiff
path: root/src/python
diff options
context:
space:
mode:
authorHind <hind.montassif@gmail.com>2021-04-23 11:27:59 +0200
committerHind <hind.montassif@gmail.com>2021-04-23 11:27:59 +0200
commitdb7ce3487e526741c0408b00c2cffda0048b0026 (patch)
tree886a7295a170823c4ae11c8f117a431fd72de39f /src/python
parent45917ecf17acacfede909994d7b3a78fc18355da (diff)
Make adjustments according to the received reviews
Diffstat (limited to 'src/python')
-rw-r--r--src/python/CMakeLists.txt2
-rw-r--r--src/python/gudhi/random_point_generators.cc45
2 files changed, 21 insertions, 26 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 8baf0f02..87f10a1a 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -43,10 +43,10 @@ endfunction( add_gudhi_debug_info )
if(PYTHONINTERP_FOUND)
if(PYBIND11_FOUND)
add_gudhi_debug_info("Pybind11 version ${PYBIND11_VERSION}")
- set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'random_point_generators', ")
set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'bottleneck', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'hera', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'clustering', ")
+ set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'random_point_generators', ")
endif()
if(CYTHON_FOUND)
set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'off_reader', ")
diff --git a/src/python/gudhi/random_point_generators.cc b/src/python/gudhi/random_point_generators.cc
index 39b09a6d..4306ba98 100644
--- a/src/python/gudhi/random_point_generators.cc
+++ b/src/python/gudhi/random_point_generators.cc
@@ -21,30 +21,25 @@ namespace py = pybind11;
typedef CGAL::Epick_d< CGAL::Dynamic_dimension_tag > Kern;
template <typename Kernel>
-py::array_t<double> generate_points_on_sphere(py::object num_points, py::object dim, py::object radius) {
- int npoints = num_points.cast<int>();
- int d = dim.cast<int>();
- double rad = radius.cast<double>();
-
- py::gil_scoped_release release;
-
- auto points_generated = Gudhi::generate_points_on_sphere_d<Kernel>(npoints, d, rad);
-
- py::gil_scoped_acquire acquire;
-
- py::array_t<double> points({npoints, d});
+py::array_t<double> generate_points_on_sphere(size_t num_points, int dim, double radius) {
+
+ py::array_t<double> points({(int)num_points, dim});
py::buffer_info buf = points.request();
-
double *ptr = static_cast<double *>(buf.ptr);
- assert(npoints == buf.shape[0]);
- assert(d == buf.shape[1]);
+ assert(num_points == buf.shape[0]);
+ assert(dim == buf.shape[1]);
-
- for (size_t i = 0; i < (size_t)npoints; i++)
- for (size_t j = 0; j < (size_t)d; j++)
- ptr[i*d+j] = points_generated.at(i).at(j);
+ std::vector<typename Kernel::Point_d> points_generated;
+ {
+ py::gil_scoped_release release;
+ points_generated = Gudhi::generate_points_on_sphere_d<Kernel>(num_points, dim, radius);
+
+ for (size_t i = 0; i < num_points; i++)
+ for (size_t j = 0; j < (size_t)dim; j++)
+ ptr[i*dim+j] = points_generated[i][j];
+ }
return points;
}
@@ -52,17 +47,17 @@ py::array_t<double> generate_points_on_sphere(py::object num_points, py::object
PYBIND11_MODULE(random_point_generators, m) {
m.attr("__license__") = "LGPL v3";
m.def("generate_points_on_sphere_d", &generate_points_on_sphere<Kern>,
- py::arg("num_points"), py::arg("dim"), py::arg("radius"),
+ py::arg("num_points"), py::arg("dim"), py::arg("radius") = 1,
R"pbdoc(
- Generate points on a sphere
+ Generate random i.i.d. points uniformly on a (d-1)-sphere in Rd
:param num_points: The number of points to be generated.
- :type num_points: integer
- :param dim: The sphere dimension.
+ :type num_points: unsigned integer
+ :param dim: The dimension.
:type dim: integer
- :param radius: The sphere radius.
+ :param radius: The radius.
:type radius: float
- :rtype: numpy array of points
+ :rtype: numpy array of float
:returns: the generated points on a sphere.
)pbdoc");
}