summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorHind-M <hind.montassif@gmail.com>2021-10-21 17:08:43 +0200
committerHind-M <hind.montassif@gmail.com>2021-10-21 17:08:43 +0200
commit36959807d5091b79aedabbc67c363dd761c9d5ee (patch)
treec40f1acf1c69897dc4ed39f5ada22d683b1598ed /src
parente23ca84fadcc2c65fd8cf2d141be804bf18b2fd6 (diff)
Factorize cpp and python torus tests implementations
Diffstat (limited to 'src')
-rwxr-xr-xsrc/python/test/test_datasets_generators.py26
1 files changed, 10 insertions, 16 deletions
diff --git a/src/python/test/test_datasets_generators.py b/src/python/test/test_datasets_generators.py
index 4c087c57..e2d300e0 100755
--- a/src/python/test/test_datasets_generators.py
+++ b/src/python/test/test_datasets_generators.py
@@ -18,22 +18,16 @@ def test_sphere():
with pytest.raises(ValueError):
points.sphere(n_samples = 10, ambient_dim = 2, radius = 1., sample = 'other')
-def test_torus():
- assert points.ctorus(n_samples = 64, dim = 3, sample = 'random').shape == (64, 6)
- assert points.ctorus(n_samples = 64, dim = 3, sample = 'grid').shape == (64, 6)
-
- assert points.ctorus(n_samples = 10, dim = 4, sample = 'random').shape == (10, 8)
- assert points.ctorus(n_samples = 10, dim = 4, sample = 'grid').shape == (1, 8)
-
- with pytest.raises(ValueError):
- points.ctorus(n_samples = 10, dim = 4, sample = 'other')
+def _basic_torus(impl):
+ assert impl(n_samples = 64, dim = 3, sample = 'random').shape == (64, 6)
+ assert impl(n_samples = 64, dim = 3, sample = 'grid').shape == (64, 6)
-def test_torus_full_python():
- assert points.torus(n_samples = 64, dim = 3, sample = 'random').shape == (64, 6)
- assert points.torus(n_samples = 64, dim = 3, sample = 'grid').shape == (64, 6)
-
- assert points.torus(n_samples = 10, dim = 4, sample = 'random').shape == (10, 8)
- assert points.torus(n_samples = 10, dim = 4, sample = 'grid').shape == (1, 8)
+ assert impl(n_samples = 10, dim = 4, sample = 'random').shape == (10, 8)
+ assert impl(n_samples = 10, dim = 4, sample = 'grid').shape == (1, 8)
with pytest.raises(ValueError):
- points.torus(n_samples = 10, dim = 4, sample = 'other')
+ impl(n_samples = 10, dim = 4, sample = 'other')
+
+def test_torus():
+ for torus_impl in [points.torus, points.ctorus]:
+ _basic_torus(torus_impl)