summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorHind-M <hind.montassif@gmail.com>2021-10-27 09:58:52 +0200
committerHind-M <hind.montassif@gmail.com>2021-10-27 09:58:52 +0200
commit3a29558decccafe0b07dbf07d66f1410df6c187f (patch)
tree7d1c2d37bb64621cf35a52ef3a72dbfa1ad489b1 /src
parentbb8c4994b89fb6bfdd80b76912acadf6197f93cc (diff)
Replace itertools in grid torus generation function with something faster in most general use cases
Diffstat (limited to 'src')
-rw-r--r--src/python/gudhi/datasets/generators/points.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/src/python/gudhi/datasets/generators/points.py b/src/python/gudhi/datasets/generators/points.py
index 7f4667af..cf97777d 100644
--- a/src/python/gudhi/datasets/generators/points.py
+++ b/src/python/gudhi/datasets/generators/points.py
@@ -8,7 +8,6 @@
# - YYYY/MM Author: Description of the modification
import numpy as np
-import itertools
from ._points import ctorus
from ._points import sphere
@@ -29,10 +28,11 @@ def _generate_grid_points_on_torus(n_samples, dim):
n_samples_grid = int((n_samples+.5)**(1./dim)) # add .5 to avoid rounding down with numerical approximations
alpha = np.linspace(0, 2*np.pi, n_samples_grid, endpoint=False)
- array_points_inter = np.column_stack([np.cos(alpha), np.sin(alpha)])
- array_points = np.array(list(itertools.product(array_points_inter, repeat=dim))).reshape(-1, 2*dim)
-
- return array_points
+ array_points = np.column_stack([np.cos(alpha), np.sin(alpha)])
+ array_points_idx = np.empty([n_samples_grid]*dim + [dim], dtype=int)
+ for i, x in enumerate(np.ix_(*([np.arange(n_samples_grid)]*dim))):
+ array_points_idx[...,i] = x
+ return array_points[array_points_idx].reshape(-1, 2*dim)
def torus(n_samples, dim, sample='random'):
"""