diff options
Diffstat (limited to 'scripts')
-rwxr-xr-x | scripts/generator/generator.py | 12 | ||||
-rw-r--r-- | scripts/generator/generator/pyclblast.py | 113 | ||||
-rw-r--r-- | scripts/generator/generator/routine.py | 32 |
3 files changed, 154 insertions, 3 deletions
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index 955625f5..8c071ab3 100755 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -18,6 +18,7 @@ # clblast_netlib_c.cpp # wrapper_clblas.h # wrapper_cblas.h +# pyclblast.pyx # It also generates the main functions for the correctness and performance tests as found in # test/correctness/routines/levelX/xYYYY.cpp # test/performance/routines/levelX/xYYYY.cpp @@ -30,6 +31,7 @@ import argparse import generator.cpp as cpp import generator.doc as doc +import generator.pyclblast as pyclblast from generator.routine import Routine from generator.datatype import H, S, D, C, Z, Sc, Dz, iH, iS, iD, iC, iZ, Css, Zdd, Ccs, Zzd, T, Tc, TU @@ -45,9 +47,10 @@ FILES = [ "/src/clblast_netlib_c.cpp", "/include/clblast_cuda.h", "/src/clblast_cuda.cpp", + "/src/pyclblast/src/pyclblast.pyx" ] -HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 95, 21] -FOOTER_LINES = [41, 56, 27, 38, 6, 6, 6, 9, 2, 41, 55] +HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 95, 21, 288] +FOOTER_LINES = [41, 56, 27, 38, 6, 6, 6, 9, 2, 41, 55, 1] HEADER_LINES_DOC = 0 FOOTER_LINES_DOC = 63 @@ -209,7 +212,8 @@ def main(argv): body = "" levels = [1, 2, 3] if (i == 4 or i == 5 or i == 6) else [1, 2, 3, 4] for level in levels: - body += cpp.LEVEL_SEPARATORS[level - 1] + "\n" + if i not in [11]: + body += cpp.LEVEL_SEPARATORS[level - 1] + "\n" for routine in ROUTINES[level - 1]: if i == 0: body += cpp.clblast_h(routine) @@ -235,6 +239,8 @@ def main(argv): body += cpp.clblast_h(routine, cuda=True) if i == 10: body += cpp.clblast_cc(routine, cuda=True) + if i == 11: + body += pyclblast.generate_pyx(routine) f.write("".join(file_header)) f.write(body) f.write("".join(file_footer)) diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py new file mode 100644 index 00000000..85bcc93f --- /dev/null +++ b/scripts/generator/generator/pyclblast.py @@ -0,0 +1,113 @@ + +# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the +# PEP8 Python style guide and uses a max-width of 120 characters per line. +# +# Author(s): +# Cedric Nugteren <www.cedricnugteren.nl> + +NL = "\n" +SEPARATOR = "####################################################################################################" + + +def to_np_dtype(flavour): + if flavour.precision_name == "S": + return "float32" + if flavour.precision_name == "D": + return "float64" + if flavour.precision_name == "C": + return "complex64" + if flavour.precision_name == "Z": + return "complex128" + raise RuntimeError("Could not convert flavour '%s' to numpy" % flavour.precision_name) + + +def scalar_cython_conversion(scalar, flavour): + scalar_type = flavour.alpha_cl if scalar == "alpha" else flavour.beta_cl + if scalar_type == "float": + return "<cl_float>" + scalar + if scalar_type == "double": + return "<cl_double>" + scalar + if scalar_type in ["cl_float2", "float2"]: + return "<cl_float2>cl_float2(x=" + scalar + ".real,y=" + scalar + ".imag)" + if scalar_type in ["cl_double2", "double2"]: + return "<cl_double2>cl_double2(x=" + scalar + ".real,y=" + scalar + ".imag)" + raise RuntimeError("Could not convert flavour '%s:%s'" % (flavour.precision_name, scalar_type)) + + +def generate_pyx(routine): + result = "" + if routine.implemented and routine.plain_name() and routine.level in ["1", "2a", "2b", "3"]: + + result += SEPARATOR + NL + result += "# " + routine.description + ": " + routine.short_names() + NL + result += SEPARATOR + NL + result += NL + + result += "cdef extern from \"clblast_c.h\":" + NL + np_dtypes = [] + for flavour in routine.flavours: + if flavour.precision_name in ["S", "D", "C", "Z"]: + result += " CLBlastStatusCode CLBlast" + flavour.name + routine.plain_name() + "(" + result += ", ".join(routine.arguments_def_c(flavour)) + "," + result += "cl_command_queue* queue, cl_event* event)" + NL + np_dtypes.append(to_np_dtype(flavour)) + result += "" + NL + + buffers = routine.inputs[:] + routine.outputs[:] + result += "def " + routine.plain_name() + "(queue, " + result += ", ".join(routine.arguments_python()) + "):" + NL + result += " dtype = check_dtype([" + ", ".join(buffers) + "], " + result += "[" + ", ".join(['"%s"' % d for d in np_dtypes]) + "])" + NL + for buf in buffers: + if buf in routine.buffers_vector(): + result += " check_vector(" + else: + result += " check_matrix(" + result += buf + ", \"" + buf + "\")" + NL + result += "" + NL + + for buf in buffers: + result += " cdef cl_mem " + buf + "_buffer = <cl_mem><size_t>" + buf + ".base_data.int_ptr" + NL + result += "" + NL + + result += " cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr" + NL + result += " cdef cl_event event = NULL" + NL + + for option in routine.options: + if option == "a_transpose": + result += " a_transpose = CLBlastTransposeYes if a_transp else CLBlastTransposeNo" + NL + if option == "b_transpose": + result += " b_transpose = CLBlastTransposeYes if b_transp else CLBlastTransposeNo" + NL + if option == "ab_transpose": + result += " ab_transpose = CLBlastTransposeYes if ab_transp else CLBlastTransposeNo" + NL + if option == "side": + result += " side = CLBlastSideRight if right_side else CLBlastSideLeft" + NL + if option == "triangle": + result += " triangle = CLBlastTriangleLower if lower_triangle else CLBlastTriangleUpper" + NL + if option == "diagonal": + result += " diagonal = CLBlastDiagonalUnit if unit_diagonal else CLBlastDiagonalNonUnit" + NL + + result += "" + NL + result += " cdef CLBlastStatusCode err" + NL + if_prefix = "" + for flavour in routine.flavours: + if flavour.precision_name in ["S", "D", "C", "Z"]: + np_dtype = to_np_dtype(flavour) + argument_names = [x. + replace("layout", "CLBlastLayoutRowMajor"). + replace("alpha", scalar_cython_conversion("alpha", flavour)). + replace("beta", scalar_cython_conversion("beta", flavour)) + for x in routine.arguments()] + result += " " + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL + result += " err = CLBlast" + flavour.name + routine.plain_name() + result += "(" + ", ".join(argument_names) + ", &command_queue, &event)" + NL + if_prefix = "el" + + result += " else:" + NL + result += " raise ValueError(\"PyCLBlast: Unrecognized data-type '%s'\" % dtype)" + NL + result += " if err != CLBlastSuccess:" + NL + result += " raise RuntimeError(\"PyCLBlast: 'CLBlastX" + routine.plain_name() + "' failed: %s\" % get_status_message(err))" + NL + result += " return cl.Event.from_int_ptr(<size_t>event)" + NL + result += NL + + return result diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py index 052709ee..c52f49ca 100644 --- a/scripts/generator/generator/routine.py +++ b/scripts/generator/generator/routine.py @@ -815,6 +815,38 @@ class Routine: list(chain(*[self.scalar_doc(s) for s in self.other_scalars()])) + self.batch_count_doc()) + def arguments_python(self): + """Arguments for the Python wrapper pyclblast""" + result = list() + result.extend(self.sizes) + buffers = self.inputs + self.outputs + result.extend(buffers[:]) + for buf in buffers: + if buf in self.buffers_matrix(): + result.append(buf + "_ld") + for buf in buffers: + if buf in self.buffers_vector(): + result.append(buf + "_inc = 1") + for scalar in self.scalars: + default = "1.0" if scalar == "alpha" else "0.0" + result.append(scalar + " = " + default) + for option in self.options: + if option == "a_transpose": + result.append("a_transp = False") + if option == "b_transpose": + result.append("b_transp = False") + if option == "ab_transpose": + result.append("ab_transp = False") + if option == "side": + result.append("right_side = False") + if option == "triangle": + result.append("lower_triangle = False") + if option == "diagonal": + result.append("unit_diagonal = False") + for buf in buffers: + result.append(buf + "_offset = 0") + return result + def requirements_doc(self): """Retrieves a list of routine requirements for documentation""" return self.requirements |