summaryrefslogtreecommitdiff
path: root/src/python/gudhi/datasets/generators
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/datasets/generators')
-rw-r--r--src/python/gudhi/datasets/generators/_points.cc4
-rw-r--r--src/python/gudhi/datasets/generators/points.py5
2 files changed, 5 insertions, 4 deletions
diff --git a/src/python/gudhi/datasets/generators/_points.cc b/src/python/gudhi/datasets/generators/_points.cc
index 6bbdf284..3d38ff90 100644
--- a/src/python/gudhi/datasets/generators/_points.cc
+++ b/src/python/gudhi/datasets/generators/_points.cc
@@ -48,6 +48,10 @@ py::array_t<double> generate_points_on_sphere(size_t n_samples, int ambient_dim,
py::array_t<double> generate_points_on_torus(size_t n_samples, int dim, std::string sample) {
+ if ( (sample != "random") && (sample != "grid")) {
+ throw pybind11::value_error("This sample type is not supported");
+ }
+
std::vector<typename Kern::Point_d> points_generated;
{
diff --git a/src/python/gudhi/datasets/generators/points.py b/src/python/gudhi/datasets/generators/points.py
index 3870dea6..daada486 100644
--- a/src/python/gudhi/datasets/generators/points.py
+++ b/src/python/gudhi/datasets/generators/points.py
@@ -23,7 +23,7 @@ def _generate_random_points_on_torus(n_samples, dim):
def _generate_grid_points_on_torus(n_samples, dim):
# Generate points on a dim-torus as a grid
- n_samples_grid = int(n_samples**(1./dim))
+ n_samples_grid = int((n_samples+.5)**(1./dim)) # add .5 to avoid rounding down with numerical approximations
alpha = np.linspace(0, 2*np.pi, n_samples_grid, endpoint=False)
array_points_inter = np.column_stack([np.cos(alpha), np.sin(alpha)])
@@ -45,12 +45,9 @@ def torus(n_samples, dim, sample='random'):
"""
if sample == 'random':
# Generate points randomly
- print("Sample is random")
return _generate_random_points_on_torus(n_samples, dim)
elif sample == 'grid':
# Generate points on a grid
- print("Sample is grid")
return _generate_grid_points_on_torus(n_samples, dim)
else:
raise ValueError("Sample type '{}' is not supported".format(sample))
- return