From b94e81af10b8cb22ac338a2a8db349222ab1b57a Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Sun, 10 May 2020 12:26:25 +0200 Subject: Added pyclblast bindings for the 3 batched routines --- scripts/generator/generator/pyclblast.py | 70 +++++++++++++++++++++++++++++--- scripts/generator/generator/routine.py | 43 +++++++++++++++----- 2 files changed, 96 insertions(+), 17 deletions(-) (limited to 'scripts/generator') diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py index 47eb2eb4..b7ec348e 100644 --- a/scripts/generator/generator/pyclblast.py +++ b/scripts/generator/generator/pyclblast.py @@ -22,6 +22,16 @@ def to_np_dtype(flavour): }[flavour.precision_name] +def cl_type(flavour): + return { + "S": "cl_float", + "D": "cl_double", + "C": "cl_float2", + "Z": "cl_double2", + "H": "cl_half", + }[flavour.precision_name] + + def scalar_cython_conversion(scalar, flavour): scalar_type = flavour.alpha_cl if scalar == "alpha" else flavour.beta_cl if scalar_type == "float": @@ -39,7 +49,9 @@ def scalar_cython_conversion(scalar, flavour): def generate_pyx(routine): result = "" - if routine.implemented and routine.plain_name() and routine.level in ["1", "2a", "2b", "3"]: + if routine.implemented and routine.plain_name() and routine.level in ["1", "2a", "2b", "3", "x"]: + if routine.level == "x" and routine.batched == 0: + return result # level-X routines that are non-batched are not supported at the moment indent = " " result += SEPARATOR + NL @@ -80,6 +92,33 @@ def generate_pyx(routine): result += buf + ", \"" + buf + "\")" + NL result += NL + # Batched checks + if routine.batched == 1: # batched but not strided-batched + lists = [b + "_offsets" for b in buffers] + [s + "s" for s in routine.scalars] + result += indent + "if " + " != ".join(["len(" + l + ")" for l in lists]) + ":" + NL + result += indent + indent + "raise RuntimeError(\"PyCLBlast: 'CLBlastX" + routine.plain_name() + "' failed: length of batch-sized arguments " + ", ".join(lists) + " should be equal\")" + NL + result += indent + "batch_count = len(" + lists[0] + ")" + NL + result += NL + + # Batched list to pointer conversions + for buf in buffers: + result += indent + "cdef size_t *" + buf + "_offsets_c = PyMem_Malloc(batch_count * sizeof(size_t))" + NL + result += indent + "for i in range(batch_count):" + NL + result += indent + indent + "" + buf + "_offsets_c[i] = " + buf + "_offsets[i]" + NL + for scalar in routine.scalars: + result += indent + "cdef void *" + scalar + "s_c = PyMem_Malloc(batch_count * sizeof(dtype_size[dtype]))" + NL + result += indent + "for i in range(batch_count):" + NL + if_prefix = "" + for flavour in routine.flavours: + if flavour.precision_name in ["S", "D", "C", "Z", "H"]: + np_dtype = to_np_dtype(flavour) + result += indent + indent + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL + scalar_converted = scalar_cython_conversion(scalar + "s[i]", flavour) + result += indent + indent + indent + "(<" + cl_type(flavour) + "*>" + scalar + "s_c)[i] = " + scalar_converted + NL + if_prefix = "el" + + result += NL + # Buffer transformation for buf in buffers: result += indent + "cdef cl_mem " + buf + "_buffer = " + buf + ".base_data.int_ptr" + NL @@ -108,11 +147,22 @@ def generate_pyx(routine): for flavour in routine.flavours: if flavour.precision_name in ["S", "D", "C", "Z", "H"]: np_dtype = to_np_dtype(flavour) - argument_names = [x. - replace("layout", "CLBlastLayoutRowMajor"). - replace("alpha", scalar_cython_conversion("alpha", flavour)). - replace("beta", scalar_cython_conversion("beta", flavour)) - for x in routine.arguments()] + if routine.batched != 1: # regular or strided-batched + argument_names = [x. + replace("layout", "CLBlastLayoutRowMajor"). + replace("alpha", scalar_cython_conversion("alpha", flavour)). + replace("beta", scalar_cython_conversion("beta", flavour)) + for x in routine.arguments()] + else: # batched but not strided-batched + argument_names = [x. + replace("layout", "CLBlastLayoutRowMajor"). + replace("_cpp", "_c"). + replace("_offsets", "_offsets_c"). + replace("alphas_c", "<" + cl_type(flavour) + "*>alphas_c"). + replace("betas_c", "<" + cl_type(flavour) + "*>betas_c") + for x in routine.arguments()] + if routine.batched > 0: + argument_names.append("batch_count") result += indent + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL result += indent + indent + "err = CLBlast" + flavour.name + routine.plain_name() result += "(" + ", ".join(argument_names) + ", &command_queue, &event)" + NL @@ -120,6 +170,14 @@ def generate_pyx(routine): result += indent + "else:" + NL result += indent + indent + "raise ValueError(\"PyCLBlast: Unrecognized data-type '%s'\" % dtype)" + NL + result += NL + + # Cleaning up + if routine.batched == 1: # batched but not strided-batched + for array in [b + "_offset" for b in buffers] + routine.scalars: + result += indent + "PyMem_Free(" + array + "s_c)" + NL + result += NL + result += indent + "if err != CLBlastSuccess:" + NL result += indent + indent + "raise RuntimeError(\"PyCLBlast: 'CLBlastX" + routine.plain_name() + "' failed: %s\" % get_status_message(err))" + NL result += indent + "return cl.Event.from_int_ptr(event)" + NL diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py index 3b5a6b76..8b6ab57f 100644 --- a/scripts/generator/generator/routine.py +++ b/scripts/generator/generator/routine.py @@ -825,17 +825,37 @@ class Routine: """Arguments for the Python wrapper pyclblast""" result = list() result.extend(self.sizes) + if self.batched == 2: # strided batched + result.append("batch_count") buffers = self.inputs + self.outputs result.extend(buffers[:]) - for buf in buffers: - if buf in self.buffers_matrix(): - result.append(buf + "_ld") - for buf in buffers: - if buf in self.buffers_vector(): - result.append(buf + "_inc = 1") - for scalar in self.scalars: - default = "1.0" if scalar == "alpha" else "0.0" - result.append(scalar + " = " + default) + if self.batched != 1: # regular or strided-batched + for buf in buffers: + if buf in self.buffers_matrix(): + result.append(buf + "_ld") + for buf in buffers: + if buf in self.buffers_vector(): + result.append(buf + "_inc = 1") + if self.batched == 2: # strided batched + for buf in buffers: + if buf in self.buffers_matrix(): + result.append(buf + "_stride") + for scalar in self.scalars: + if scalar != "": + default = "1.0" if scalar == "alpha" else "0.0" + result.append(scalar + " = " + default) + else: # batched but not strided-batched + for scalar in self.scalars: + result.append(scalar + "s") + for buf in buffers: + if buf in self.buffers_matrix(): + result.append(buf + "_ld") + for buf in buffers: + if buf in self.buffers_vector() + self.buffers_matrix(): + result.append(buf + "_offsets") + for buf in buffers: + if buf in self.buffers_vector(): + result.append(buf + "_inc = 1") for option in self.options: if option == "a_transpose": result.append("a_transp = False") @@ -849,8 +869,9 @@ class Routine: result.append("lower_triangle = False") if option == "diagonal": result.append("unit_diagonal = False") - for buf in buffers: - result.append(buf + "_offset = 0") + if self.batched != 1: + for buf in buffers: + result.append(buf + "_offset = 0") return result def requirements_doc(self): -- cgit v1.2.3