summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2020-05-10 12:26:25 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2020-05-10 12:26:25 +0200
commitb94e81af10b8cb22ac338a2a8db349222ab1b57a (patch)
tree4909f1e0fe37f84381ff78c2e13bc56d25f0e03b /scripts
parent5f4b3ffcf7c8f90eee33a8504ede00ea52f79c0e (diff)
Added pyclblast bindings for the 3 batched routines
Diffstat (limited to 'scripts')
-rw-r--r--scripts/generator/generator/pyclblast.py70
-rw-r--r--scripts/generator/generator/routine.py43
2 files changed, 96 insertions, 17 deletions
diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py
index 47eb2eb4..b7ec348e 100644
--- a/scripts/generator/generator/pyclblast.py
+++ b/scripts/generator/generator/pyclblast.py
@@ -22,6 +22,16 @@ def to_np_dtype(flavour):
}[flavour.precision_name]
+def cl_type(flavour):
+ return {
+ "S": "cl_float",
+ "D": "cl_double",
+ "C": "cl_float2",
+ "Z": "cl_double2",
+ "H": "cl_half",
+ }[flavour.precision_name]
+
+
def scalar_cython_conversion(scalar, flavour):
scalar_type = flavour.alpha_cl if scalar == "alpha" else flavour.beta_cl
if scalar_type == "float":
@@ -39,7 +49,9 @@ def scalar_cython_conversion(scalar, flavour):
def generate_pyx(routine):
result = ""
- if routine.implemented and routine.plain_name() and routine.level in ["1", "2a", "2b", "3"]:
+ if routine.implemented and routine.plain_name() and routine.level in ["1", "2a", "2b", "3", "x"]:
+ if routine.level == "x" and routine.batched == 0:
+ return result # level-X routines that are non-batched are not supported at the moment
indent = " "
result += SEPARATOR + NL
@@ -80,6 +92,33 @@ def generate_pyx(routine):
result += buf + ", \"" + buf + "\")" + NL
result += NL
+ # Batched checks
+ if routine.batched == 1: # batched but not strided-batched
+ lists = [b + "_offsets" for b in buffers] + [s + "s" for s in routine.scalars]
+ result += indent + "if " + " != ".join(["len(" + l + ")" for l in lists]) + ":" + NL
+ result += indent + indent + "raise RuntimeError(\"PyCLBlast: 'CLBlastX" + routine.plain_name() + "' failed: length of batch-sized arguments " + ", ".join(lists) + " should be equal\")" + NL
+ result += indent + "batch_count = len(" + lists[0] + ")" + NL
+ result += NL
+
+ # Batched list to pointer conversions
+ for buf in buffers:
+ result += indent + "cdef size_t *" + buf + "_offsets_c = <size_t *> PyMem_Malloc(batch_count * sizeof(size_t))" + NL
+ result += indent + "for i in range(batch_count):" + NL
+ result += indent + indent + "" + buf + "_offsets_c[i] = " + buf + "_offsets[i]" + NL
+ for scalar in routine.scalars:
+ result += indent + "cdef void *" + scalar + "s_c = <void *> PyMem_Malloc(batch_count * sizeof(dtype_size[dtype]))" + NL
+ result += indent + "for i in range(batch_count):" + NL
+ if_prefix = ""
+ for flavour in routine.flavours:
+ if flavour.precision_name in ["S", "D", "C", "Z", "H"]:
+ np_dtype = to_np_dtype(flavour)
+ result += indent + indent + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL
+ scalar_converted = scalar_cython_conversion(scalar + "s[i]", flavour)
+ result += indent + indent + indent + "(<" + cl_type(flavour) + "*>" + scalar + "s_c)[i] = " + scalar_converted + NL
+ if_prefix = "el"
+
+ result += NL
+
# Buffer transformation
for buf in buffers:
result += indent + "cdef cl_mem " + buf + "_buffer = <cl_mem><size_t>" + buf + ".base_data.int_ptr" + NL
@@ -108,11 +147,22 @@ def generate_pyx(routine):
for flavour in routine.flavours:
if flavour.precision_name in ["S", "D", "C", "Z", "H"]:
np_dtype = to_np_dtype(flavour)
- argument_names = [x.
- replace("layout", "CLBlastLayoutRowMajor").
- replace("alpha", scalar_cython_conversion("alpha", flavour)).
- replace("beta", scalar_cython_conversion("beta", flavour))
- for x in routine.arguments()]
+ if routine.batched != 1: # regular or strided-batched
+ argument_names = [x.
+ replace("layout", "CLBlastLayoutRowMajor").
+ replace("alpha", scalar_cython_conversion("alpha", flavour)).
+ replace("beta", scalar_cython_conversion("beta", flavour))
+ for x in routine.arguments()]
+ else: # batched but not strided-batched
+ argument_names = [x.
+ replace("layout", "CLBlastLayoutRowMajor").
+ replace("_cpp", "_c").
+ replace("_offsets", "_offsets_c").
+ replace("alphas_c", "<" + cl_type(flavour) + "*>alphas_c").
+ replace("betas_c", "<" + cl_type(flavour) + "*>betas_c")
+ for x in routine.arguments()]
+ if routine.batched > 0:
+ argument_names.append("batch_count")
result += indent + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL
result += indent + indent + "err = CLBlast" + flavour.name + routine.plain_name()
result += "(" + ", ".join(argument_names) + ", &command_queue, &event)" + NL
@@ -120,6 +170,14 @@ def generate_pyx(routine):
result += indent + "else:" + NL
result += indent + indent + "raise ValueError(\"PyCLBlast: Unrecognized data-type '%s'\" % dtype)" + NL
+ result += NL
+
+ # Cleaning up
+ if routine.batched == 1: # batched but not strided-batched
+ for array in [b + "_offset" for b in buffers] + routine.scalars:
+ result += indent + "PyMem_Free(" + array + "s_c)" + NL
+ result += NL
+
result += indent + "if err != CLBlastSuccess:" + NL
result += indent + indent + "raise RuntimeError(\"PyCLBlast: 'CLBlastX" + routine.plain_name() + "' failed: %s\" % get_status_message(err))" + NL
result += indent + "return cl.Event.from_int_ptr(<size_t>event)" + NL
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index 3b5a6b76..8b6ab57f 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -825,17 +825,37 @@ class Routine:
"""Arguments for the Python wrapper pyclblast"""
result = list()
result.extend(self.sizes)
+ if self.batched == 2: # strided batched
+ result.append("batch_count")
buffers = self.inputs + self.outputs
result.extend(buffers[:])
- for buf in buffers:
- if buf in self.buffers_matrix():
- result.append(buf + "_ld")
- for buf in buffers:
- if buf in self.buffers_vector():
- result.append(buf + "_inc = 1")
- for scalar in self.scalars:
- default = "1.0" if scalar == "alpha" else "0.0"
- result.append(scalar + " = " + default)
+ if self.batched != 1: # regular or strided-batched
+ for buf in buffers:
+ if buf in self.buffers_matrix():
+ result.append(buf + "_ld")
+ for buf in buffers:
+ if buf in self.buffers_vector():
+ result.append(buf + "_inc = 1")
+ if self.batched == 2: # strided batched
+ for buf in buffers:
+ if buf in self.buffers_matrix():
+ result.append(buf + "_stride")
+ for scalar in self.scalars:
+ if scalar != "":
+ default = "1.0" if scalar == "alpha" else "0.0"
+ result.append(scalar + " = " + default)
+ else: # batched but not strided-batched
+ for scalar in self.scalars:
+ result.append(scalar + "s")
+ for buf in buffers:
+ if buf in self.buffers_matrix():
+ result.append(buf + "_ld")
+ for buf in buffers:
+ if buf in self.buffers_vector() + self.buffers_matrix():
+ result.append(buf + "_offsets")
+ for buf in buffers:
+ if buf in self.buffers_vector():
+ result.append(buf + "_inc = 1")
for option in self.options:
if option == "a_transpose":
result.append("a_transp = False")
@@ -849,8 +869,9 @@ class Routine:
result.append("lower_triangle = False")
if option == "diagonal":
result.append("unit_diagonal = False")
- for buf in buffers:
- result.append(buf + "_offset = 0")
+ if self.batched != 1:
+ for buf in buffers:
+ result.append(buf + "_offset = 0")
return result
def requirements_doc(self):