summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-02-25 14:51:58 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2018-02-25 14:51:58 +0100
commit6710c609354958e81be422480a996ef6348b749a (patch)
tree94be209a04a1521963c81776a2bb3eb55fa12d7d /scripts
parent9699169cdf019d30dbd6a853a31d8c445804ab54 (diff)
Some style improvements in the pyclblast code generator
Diffstat (limited to 'scripts')
-rw-r--r--scripts/generator/generator/pyclblast.py63
1 files changed, 32 insertions, 31 deletions
diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py
index 85bcc93f..8075d209 100644
--- a/scripts/generator/generator/pyclblast.py
+++ b/scripts/generator/generator/pyclblast.py
@@ -5,20 +5,20 @@
# Author(s):
# Cedric Nugteren <www.cedricnugteren.nl>
-NL = "\n"
+import os
+
+
+NL = os.linesep
SEPARATOR = "####################################################################################################"
def to_np_dtype(flavour):
- if flavour.precision_name == "S":
- return "float32"
- if flavour.precision_name == "D":
- return "float64"
- if flavour.precision_name == "C":
- return "complex64"
- if flavour.precision_name == "Z":
- return "complex128"
- raise RuntimeError("Could not convert flavour '%s' to numpy" % flavour.precision_name)
+ return {
+ "S": "float32",
+ "D": "float64",
+ "C": "complex64",
+ "Z": "complex128",
+ }[flavour.precision_name]
def scalar_cython_conversion(scalar, flavour):
@@ -37,6 +37,7 @@ def scalar_cython_conversion(scalar, flavour):
def generate_pyx(routine):
result = ""
if routine.implemented and routine.plain_name() and routine.level in ["1", "2a", "2b", "3"]:
+ indent = " "
result += SEPARATOR + NL
result += "# " + routine.description + ": " + routine.short_names() + NL
@@ -47,7 +48,7 @@ def generate_pyx(routine):
np_dtypes = []
for flavour in routine.flavours:
if flavour.precision_name in ["S", "D", "C", "Z"]:
- result += " CLBlastStatusCode CLBlast" + flavour.name + routine.plain_name() + "("
+ result += indent + "CLBlastStatusCode CLBlast" + flavour.name + routine.plain_name() + "("
result += ", ".join(routine.arguments_def_c(flavour)) + ","
result += "cl_command_queue* queue, cl_event* event)" + NL
np_dtypes.append(to_np_dtype(flavour))
@@ -56,39 +57,39 @@ def generate_pyx(routine):
buffers = routine.inputs[:] + routine.outputs[:]
result += "def " + routine.plain_name() + "(queue, "
result += ", ".join(routine.arguments_python()) + "):" + NL
- result += " dtype = check_dtype([" + ", ".join(buffers) + "], "
+ result += indent + "dtype = check_dtype([" + ", ".join(buffers) + "], "
result += "[" + ", ".join(['"%s"' % d for d in np_dtypes]) + "])" + NL
for buf in buffers:
if buf in routine.buffers_vector():
- result += " check_vector("
+ result += indent + "check_vector("
else:
- result += " check_matrix("
+ result += indent + "check_matrix("
result += buf + ", \"" + buf + "\")" + NL
result += "" + NL
for buf in buffers:
- result += " cdef cl_mem " + buf + "_buffer = <cl_mem><size_t>" + buf + ".base_data.int_ptr" + NL
+ result += indent + "cdef cl_mem " + buf + "_buffer = <cl_mem><size_t>" + buf + ".base_data.int_ptr" + NL
result += "" + NL
- result += " cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr" + NL
- result += " cdef cl_event event = NULL" + NL
+ result += indent + "cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr" + NL
+ result += indent + "cdef cl_event event = NULL" + NL
for option in routine.options:
if option == "a_transpose":
- result += " a_transpose = CLBlastTransposeYes if a_transp else CLBlastTransposeNo" + NL
+ result += indent + "a_transpose = CLBlastTransposeYes if a_transp else CLBlastTransposeNo" + NL
if option == "b_transpose":
- result += " b_transpose = CLBlastTransposeYes if b_transp else CLBlastTransposeNo" + NL
+ result += indent + "b_transpose = CLBlastTransposeYes if b_transp else CLBlastTransposeNo" + NL
if option == "ab_transpose":
- result += " ab_transpose = CLBlastTransposeYes if ab_transp else CLBlastTransposeNo" + NL
+ result += indent + "ab_transpose = CLBlastTransposeYes if ab_transp else CLBlastTransposeNo" + NL
if option == "side":
- result += " side = CLBlastSideRight if right_side else CLBlastSideLeft" + NL
+ result += indent + "side = CLBlastSideRight if right_side else CLBlastSideLeft" + NL
if option == "triangle":
- result += " triangle = CLBlastTriangleLower if lower_triangle else CLBlastTriangleUpper" + NL
+ result += indent + "triangle = CLBlastTriangleLower if lower_triangle else CLBlastTriangleUpper" + NL
if option == "diagonal":
- result += " diagonal = CLBlastDiagonalUnit if unit_diagonal else CLBlastDiagonalNonUnit" + NL
+ result += indent + "diagonal = CLBlastDiagonalUnit if unit_diagonal else CLBlastDiagonalNonUnit" + NL
result += "" + NL
- result += " cdef CLBlastStatusCode err" + NL
+ result += indent + "cdef CLBlastStatusCode err" + NL
if_prefix = ""
for flavour in routine.flavours:
if flavour.precision_name in ["S", "D", "C", "Z"]:
@@ -98,16 +99,16 @@ def generate_pyx(routine):
replace("alpha", scalar_cython_conversion("alpha", flavour)).
replace("beta", scalar_cython_conversion("beta", flavour))
for x in routine.arguments()]
- result += " " + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL
- result += " err = CLBlast" + flavour.name + routine.plain_name()
+ result += indent + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL
+ result += indent + indent + "err = CLBlast" + flavour.name + routine.plain_name()
result += "(" + ", ".join(argument_names) + ", &command_queue, &event)" + NL
if_prefix = "el"
- result += " else:" + NL
- result += " raise ValueError(\"PyCLBlast: Unrecognized data-type '%s'\" % dtype)" + NL
- result += " if err != CLBlastSuccess:" + NL
- result += " raise RuntimeError(\"PyCLBlast: 'CLBlastX" + routine.plain_name() + "' failed: %s\" % get_status_message(err))" + NL
- result += " return cl.Event.from_int_ptr(<size_t>event)" + NL
+ result += indent + "else:" + NL
+ result += indent + indent + "raise ValueError(\"PyCLBlast: Unrecognized data-type '%s'\" % dtype)" + NL
+ result += indent + "if err != CLBlastSuccess:" + NL
+ result += indent + indent + "raise RuntimeError(\"PyCLBlast: 'CLBlastX" + routine.plain_name() + "' failed: %s\" % get_status_message(err))" + NL
+ result += indent + "return cl.Event.from_int_ptr(<size_t>event)" + NL
result += NL
return result