diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-02-25 14:51:58 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2018-02-25 14:51:58 +0100 |
commit | 6710c609354958e81be422480a996ef6348b749a (patch) | |
tree | 94be209a04a1521963c81776a2bb3eb55fa12d7d /scripts/generator | |
parent | 9699169cdf019d30dbd6a853a31d8c445804ab54 (diff) |
Some style improvements in the pyclblast code generator
Diffstat (limited to 'scripts/generator')
-rw-r--r-- | scripts/generator/generator/pyclblast.py | 63 |
1 files changed, 32 insertions, 31 deletions
diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py index 85bcc93f..8075d209 100644 --- a/scripts/generator/generator/pyclblast.py +++ b/scripts/generator/generator/pyclblast.py @@ -5,20 +5,20 @@ # Author(s): # Cedric Nugteren <www.cedricnugteren.nl> -NL = "\n" +import os + + +NL = os.linesep SEPARATOR = "####################################################################################################" def to_np_dtype(flavour): - if flavour.precision_name == "S": - return "float32" - if flavour.precision_name == "D": - return "float64" - if flavour.precision_name == "C": - return "complex64" - if flavour.precision_name == "Z": - return "complex128" - raise RuntimeError("Could not convert flavour '%s' to numpy" % flavour.precision_name) + return { + "S": "float32", + "D": "float64", + "C": "complex64", + "Z": "complex128", + }[flavour.precision_name] def scalar_cython_conversion(scalar, flavour): @@ -37,6 +37,7 @@ def scalar_cython_conversion(scalar, flavour): def generate_pyx(routine): result = "" if routine.implemented and routine.plain_name() and routine.level in ["1", "2a", "2b", "3"]: + indent = " " result += SEPARATOR + NL result += "# " + routine.description + ": " + routine.short_names() + NL @@ -47,7 +48,7 @@ def generate_pyx(routine): np_dtypes = [] for flavour in routine.flavours: if flavour.precision_name in ["S", "D", "C", "Z"]: - result += " CLBlastStatusCode CLBlast" + flavour.name + routine.plain_name() + "(" + result += indent + "CLBlastStatusCode CLBlast" + flavour.name + routine.plain_name() + "(" result += ", ".join(routine.arguments_def_c(flavour)) + "," result += "cl_command_queue* queue, cl_event* event)" + NL np_dtypes.append(to_np_dtype(flavour)) @@ -56,39 +57,39 @@ def generate_pyx(routine): buffers = routine.inputs[:] + routine.outputs[:] result += "def " + routine.plain_name() + "(queue, " result += ", ".join(routine.arguments_python()) + "):" + NL - result += " dtype = check_dtype([" + ", ".join(buffers) + "], " + result += indent + "dtype = check_dtype([" + ", ".join(buffers) + "], " result += "[" + ", ".join(['"%s"' % d for d in np_dtypes]) + "])" + NL for buf in buffers: if buf in routine.buffers_vector(): - result += " check_vector(" + result += indent + "check_vector(" else: - result += " check_matrix(" + result += indent + "check_matrix(" result += buf + ", \"" + buf + "\")" + NL result += "" + NL for buf in buffers: - result += " cdef cl_mem " + buf + "_buffer = <cl_mem><size_t>" + buf + ".base_data.int_ptr" + NL + result += indent + "cdef cl_mem " + buf + "_buffer = <cl_mem><size_t>" + buf + ".base_data.int_ptr" + NL result += "" + NL - result += " cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr" + NL - result += " cdef cl_event event = NULL" + NL + result += indent + "cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr" + NL + result += indent + "cdef cl_event event = NULL" + NL for option in routine.options: if option == "a_transpose": - result += " a_transpose = CLBlastTransposeYes if a_transp else CLBlastTransposeNo" + NL + result += indent + "a_transpose = CLBlastTransposeYes if a_transp else CLBlastTransposeNo" + NL if option == "b_transpose": - result += " b_transpose = CLBlastTransposeYes if b_transp else CLBlastTransposeNo" + NL + result += indent + "b_transpose = CLBlastTransposeYes if b_transp else CLBlastTransposeNo" + NL if option == "ab_transpose": - result += " ab_transpose = CLBlastTransposeYes if ab_transp else CLBlastTransposeNo" + NL + result += indent + "ab_transpose = CLBlastTransposeYes if ab_transp else CLBlastTransposeNo" + NL if option == "side": - result += " side = CLBlastSideRight if right_side else CLBlastSideLeft" + NL + result += indent + "side = CLBlastSideRight if right_side else CLBlastSideLeft" + NL if option == "triangle": - result += " triangle = CLBlastTriangleLower if lower_triangle else CLBlastTriangleUpper" + NL + result += indent + "triangle = CLBlastTriangleLower if lower_triangle else CLBlastTriangleUpper" + NL if option == "diagonal": - result += " diagonal = CLBlastDiagonalUnit if unit_diagonal else CLBlastDiagonalNonUnit" + NL + result += indent + "diagonal = CLBlastDiagonalUnit if unit_diagonal else CLBlastDiagonalNonUnit" + NL result += "" + NL - result += " cdef CLBlastStatusCode err" + NL + result += indent + "cdef CLBlastStatusCode err" + NL if_prefix = "" for flavour in routine.flavours: if flavour.precision_name in ["S", "D", "C", "Z"]: @@ -98,16 +99,16 @@ def generate_pyx(routine): replace("alpha", scalar_cython_conversion("alpha", flavour)). replace("beta", scalar_cython_conversion("beta", flavour)) for x in routine.arguments()] - result += " " + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL - result += " err = CLBlast" + flavour.name + routine.plain_name() + result += indent + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL + result += indent + indent + "err = CLBlast" + flavour.name + routine.plain_name() result += "(" + ", ".join(argument_names) + ", &command_queue, &event)" + NL if_prefix = "el" - result += " else:" + NL - result += " raise ValueError(\"PyCLBlast: Unrecognized data-type '%s'\" % dtype)" + NL - result += " if err != CLBlastSuccess:" + NL - result += " raise RuntimeError(\"PyCLBlast: 'CLBlastX" + routine.plain_name() + "' failed: %s\" % get_status_message(err))" + NL - result += " return cl.Event.from_int_ptr(<size_t>event)" + NL + result += indent + "else:" + NL + result += indent + indent + "raise ValueError(\"PyCLBlast: Unrecognized data-type '%s'\" % dtype)" + NL + result += indent + "if err != CLBlastSuccess:" + NL + result += indent + indent + "raise RuntimeError(\"PyCLBlast: 'CLBlastX" + routine.plain_name() + "' failed: %s\" % get_status_message(err))" + NL + result += indent + "return cl.Event.from_int_ptr(<size_t>event)" + NL result += NL return result |