From 3a29558decccafe0b07dbf07d66f1410df6c187f Mon Sep 17 00:00:00 2001 From: Hind-M Date: Wed, 27 Oct 2021 09:58:52 +0200 Subject: Replace itertools in grid torus generation function with something faster in most general use cases --- src/python/gudhi/datasets/generators/points.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/python/gudhi/datasets/generators/points.py b/src/python/gudhi/datasets/generators/points.py index 7f4667af..cf97777d 100644 --- a/src/python/gudhi/datasets/generators/points.py +++ b/src/python/gudhi/datasets/generators/points.py @@ -8,7 +8,6 @@ # - YYYY/MM Author: Description of the modification import numpy as np -import itertools from ._points import ctorus from ._points import sphere @@ -29,10 +28,11 @@ def _generate_grid_points_on_torus(n_samples, dim): n_samples_grid = int((n_samples+.5)**(1./dim)) # add .5 to avoid rounding down with numerical approximations alpha = np.linspace(0, 2*np.pi, n_samples_grid, endpoint=False) - array_points_inter = np.column_stack([np.cos(alpha), np.sin(alpha)]) - array_points = np.array(list(itertools.product(array_points_inter, repeat=dim))).reshape(-1, 2*dim) - - return array_points + array_points = np.column_stack([np.cos(alpha), np.sin(alpha)]) + array_points_idx = np.empty([n_samples_grid]*dim + [dim], dtype=int) + for i, x in enumerate(np.ix_(*([np.arange(n_samples_grid)]*dim))): + array_points_idx[...,i] = x + return array_points[array_points_idx].reshape(-1, 2*dim) def torus(n_samples, dim, sample='random'): """ -- cgit v1.2.3