summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-02-14 20:50:47 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2018-02-14 20:50:47 +0100
commiteb85f6b514b285d7dde1ac02b97b7581a46ff21d (patch)
treed32b88f353162d53fff1d82a82890c4080a9565a /scripts
parent61b8c771ed906720459b029d91f97c7df0785938 (diff)
First agenerated version (clblastXswap only for now) of the pyclblast wrapper
Diffstat (limited to 'scripts')
-rwxr-xr-xscripts/generator/generator.py12
-rw-r--r--scripts/generator/generator/convert.py13
-rw-r--r--scripts/generator/generator/pyclblast.py81
-rw-r--r--scripts/generator/generator/routine.py19
4 files changed, 122 insertions, 3 deletions
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 955625f5..c25d0e4f 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -18,6 +18,7 @@
# clblast_netlib_c.cpp
# wrapper_clblas.h
# wrapper_cblas.h
+# pyclblast.pyx
# It also generates the main functions for the correctness and performance tests as found in
# test/correctness/routines/levelX/xYYYY.cpp
# test/performance/routines/levelX/xYYYY.cpp
@@ -30,6 +31,7 @@ import argparse
import generator.cpp as cpp
import generator.doc as doc
+import generator.pyclblast as pyclblast
from generator.routine import Routine
from generator.datatype import H, S, D, C, Z, Sc, Dz, iH, iS, iD, iC, iZ, Css, Zdd, Ccs, Zzd, T, Tc, TU
@@ -45,9 +47,10 @@ FILES = [
"/src/clblast_netlib_c.cpp",
"/include/clblast_cuda.h",
"/src/clblast_cuda.cpp",
+ "/src/pyclblast/pyclblast/pyclblast.pyx"
]
-HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 95, 21]
-FOOTER_LINES = [41, 56, 27, 38, 6, 6, 6, 9, 2, 41, 55]
+HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 95, 21, 288]
+FOOTER_LINES = [41, 56, 27, 38, 6, 6, 6, 9, 2, 41, 55, 1]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 63
@@ -209,7 +212,8 @@ def main(argv):
body = ""
levels = [1, 2, 3] if (i == 4 or i == 5 or i == 6) else [1, 2, 3, 4]
for level in levels:
- body += cpp.LEVEL_SEPARATORS[level - 1] + "\n"
+ if i not in [11]:
+ body += cpp.LEVEL_SEPARATORS[level - 1] + "\n"
for routine in ROUTINES[level - 1]:
if i == 0:
body += cpp.clblast_h(routine)
@@ -235,6 +239,8 @@ def main(argv):
body += cpp.clblast_h(routine, cuda=True)
if i == 10:
body += cpp.clblast_cc(routine, cuda=True)
+ if i == 11:
+ body += pyclblast.generate_pyx(routine)
f.write("".join(file_header))
f.write(body)
f.write("".join(file_footer))
diff --git a/scripts/generator/generator/convert.py b/scripts/generator/generator/convert.py
index 07f45669..44eb69d6 100644
--- a/scripts/generator/generator/convert.py
+++ b/scripts/generator/generator/convert.py
@@ -80,3 +80,16 @@ def option_to_documentation(x):
'triangle': "The part of the array of the triangular matrix to be used, either `Triangle::kUpper` (121) or `Triangle::kLower` (122).",
'diagonal': "The property of the diagonal matrix, either `Diagonal::kNonUnit` (131) for non-unit values on the diagonal or `Diagonal::kUnit` (132) for unit values on the diagonal.",
}[x]
+
+
+def option_to_clblastdefault(x):
+ """Translates an option name to a CLBlast C default type"""
+ return {
+ 'layout': "CLBlastLayoutColMajor",
+ 'a_transpose': "CLBlastTransposeNo",
+ 'b_transpose': "CLBlastTransposeNo",
+ 'ab_transpose': "CLBlastTransposeNo",
+ 'side': "CLBlastSideLeft",
+ 'triangle': "CLBlastTriangleUpper",
+ 'diagonal': "CLBlastDiagonalNonUnit",
+ }[x]
diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py
new file mode 100644
index 00000000..089a410a
--- /dev/null
+++ b/scripts/generator/generator/pyclblast.py
@@ -0,0 +1,81 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+NL = "\n"
+SEPARATOR = "####################################################################################################"
+
+
+def to_np_dtype(flavour):
+ if flavour.precision_name == "S":
+ return "float32"
+ if flavour.precision_name == "D":
+ return "float64"
+ if flavour.precision_name == "C":
+ return "complex64"
+ if flavour.precision_name == "Z":
+ return "complex128"
+ raise RuntimeError("Could not convert flavour '%s' to numpy" % flavour.precision_name)
+
+
+def generate_pyx(routine):
+ result = ""
+ if routine.implemented and routine.plain_name() == "swap": # TODO: Generalize
+
+ result += SEPARATOR + NL
+ result += "# " + routine.description + ": " + routine.short_names() + NL
+ result += SEPARATOR + NL
+ result += NL
+
+ result += "cdef extern from \"clblast_c.h\":" + NL
+ np_dtypes = []
+ for flavour in routine.flavours:
+ if flavour.precision_name in ["S", "D", "C", "Z"]:
+ result += " CLBlastStatusCode CLBlast" + flavour.name + routine.plain_name() + "("
+ result += ", ".join(routine.arguments_def_c(flavour)) + ","
+ result += "cl_command_queue* queue, cl_event* event)" + NL
+ np_dtypes.append(to_np_dtype(flavour))
+ result += "" + NL
+
+ buffers = routine.inputs[:] + routine.outputs[:]
+ result += "def " + routine.plain_name() + "(queue, "
+ result += ", ".join(routine.arguments_python()) + "):" + NL
+ result += " dtype = check_dtype([" + ", ".join(buffers) + "], "
+ result += "[" + ", ".join(['"%s"' % d for d in np_dtypes]) + "])" + NL
+ for buf in buffers:
+ if buf in routine.buffers_vector():
+ result += " check_vector("
+ else:
+ result += " check_matrix("
+ result += buf + ", \"" + buf + "\")" + NL
+ result += "" + NL
+
+ for buf in buffers:
+ result += " cdef cl_mem " + buf + "_buffer = <cl_mem><size_t>" + buf + ".base_data.int_ptr" + NL
+ result += "" + NL
+
+ result += " cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr" + NL
+ result += " cdef cl_event event = NULL" + NL
+ result += "" + NL
+
+ result += " cdef CLBlastStatusCode err" + NL
+ if_prefix = ""
+ for flavour in routine.flavours:
+ if flavour.precision_name in ["S", "D", "C", "Z"]:
+ np_dtype = to_np_dtype(flavour)
+ result += " " + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL
+ result += " err = CLBlast" + flavour.name + routine.plain_name()
+ result += "(" + ", ".join(routine.arguments()) + ", &command_queue, &event)" + NL
+ if_prefix = "el"
+
+ result += " else:" + NL
+ result += " raise ValueError(\"PyCLBlast: Unrecognized data-type '%s'\" % dtype)" + NL
+ result += " if err != CLBlastSuccess:" + NL
+ result += " raise RuntimeError(\"PyCLBlast: 'CLBlastX" + routine.plain_name() + "' failed: %s\" % get_status_message(err))" + NL
+ result += " return cl.Event.from_int_ptr(<size_t>event)" + NL
+ result += NL
+
+ return result
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index 052709ee..d0b0a6d7 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -815,6 +815,25 @@ class Routine:
list(chain(*[self.scalar_doc(s) for s in self.other_scalars()])) +
self.batch_count_doc())
+ def arguments_python(self):
+ """Arguments for the Python wrapper pyclblast"""
+ result = list()
+ result.extend(self.sizes)
+ buffers = self.inputs + self.outputs
+ result.extend(buffers[:])
+ for buf in buffers:
+ if buf in self.buffers_matrix():
+ result.append(buf + "_ld")
+ for buf in buffers:
+ if buf in self.buffers_vector():
+ result.append(buf + "_inc = 1")
+ for option in self.options:
+ default = convert.option_to_clblastdefault(option)
+ result.append(option + " = " + default)
+ for buf in buffers:
+ result.append(buf + "_offset = 0")
+ return result
+
def requirements_doc(self):
"""Retrieves a list of routine requirements for documentation"""
return self.requirements