summary | refs | log | tree | commit | diff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rwxr-xr-x scripts/generator/generator.py 12
-rw-r--r-- scripts/generator/generator/pyclblast.py 113
-rw-r--r-- scripts/generator/generator/routine.py 32
3 files changed, 154 insertions, 3 deletions
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 955625f5..8c071ab3 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -18,6 +18,7 @@
# clblast_netlib_c.cpp
# wrapper_clblas.h
# wrapper_cblas.h
+# pyclblast.pyx
# It also generates the main functions for the correctness and performance tests as found in
# test/correctness/routines/levelX/xYYYY.cpp
# test/performance/routines/levelX/xYYYY.cpp
@@ -30,6 +31,7 @@ import argparse
import generator.cpp as cpp
import generator.doc as doc
+import generator.pyclblast as pyclblast
from generator.routine import Routine
from generator.datatype import H, S, D, C, Z, Sc, Dz, iH, iS, iD, iC, iZ, Css, Zdd, Ccs, Zzd, T, Tc, TU
@@ -45,9 +47,10 @@ FILES = [
"/src/clblast_netlib_c.cpp",
"/include/clblast_cuda.h",
"/src/clblast_cuda.cpp",
+ "/src/pyclblast/src/pyclblast.pyx"
]
-HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 95, 21]
-FOOTER_LINES = [41, 56, 27, 38, 6, 6, 6, 9, 2, 41, 55]
+HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 95, 21, 288]
+FOOTER_LINES = [41, 56, 27, 38, 6, 6, 6, 9, 2, 41, 55, 1]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 63
@@ -209,7 +212,8 @@ def main(argv):
body = ""
levels = [1, 2, 3] if (i == 4 or i == 5 or i == 6) else [1, 2, 3, 4]
for level in levels:
- body += cpp.LEVEL_SEPARATORS[level - 1] + "\n"
+ if i not in [11]:
+ body += cpp.LEVEL_SEPARATORS[level - 1] + "\n"
for routine in ROUTINES[level - 1]:
if i == 0:
body += cpp.clblast_h(routine)
@@ -235,6 +239,8 @@ def main(argv):
body += cpp.clblast_h(routine, cuda=True)
if i == 10:
body += cpp.clblast_cc(routine, cuda=True)
+ if i == 11:
+ body += pyclblast.generate_pyx(routine)
f.write("".join(file_header))
f.write(body)
f.write("".join(file_footer))
diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py
new file mode 100644
index 00000000..85bcc93f
--- /dev/null
+++ b/scripts/generator/generator/pyclblast.py
@@ -0,0 +1,113 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+NL = "\n"
+SEPARATOR = "####################################################################################################"
+
+
+def to_np_dtype(flavour):  # Maps flavour.precision_name to a numpy dtype name, returned as a string
+ if flavour.precision_name == "S":  # single-precision real
+ return "float32"
+ if flavour.precision_name == "D":  # double-precision real
+ return "float64"
+ if flavour.precision_name == "C":  # single-precision complex
+ return "complex64"
+ if flavour.precision_name == "Z":  # double-precision complex
+ return "complex128"
+ raise RuntimeError("Could not convert flavour '%s' to numpy" % flavour.precision_name)  # not S/D/C/Z
+
+
+def scalar_cython_conversion(scalar, flavour):  # Returns a Cython expression string casting the named scalar
+ scalar_type = flavour.alpha_cl if scalar == "alpha" else flavour.beta_cl  # any name but "alpha" uses beta_cl
+ if scalar_type == "float":
+ return "<cl_float>" + scalar  # plain cast of the variable called `scalar` in the generated code
+ if scalar_type == "double":
+ return "<cl_double>" + scalar
+ if scalar_type in ["cl_float2", "float2"]:  # both spellings of the complex type are accepted
+ return "<cl_float2>cl_float2(x=" + scalar + ".real,y=" + scalar + ".imag)"  # built from .real and .imag
+ if scalar_type in ["cl_double2", "double2"]:
+ return "<cl_double2>cl_double2(x=" + scalar + ".real,y=" + scalar + ".imag)"
+ raise RuntimeError("Could not convert flavour '%s:%s'" % (flavour.precision_name, scalar_type))  # other types
+
+
+def generate_pyx(routine):  # Returns the generated .pyx source text for one routine, or "" if it is skipped
+ result = ""
+ if routine.implemented and routine.plain_name() and routine.level in ["1", "2a", "2b", "3"]:  # else: ""
+
+ result += SEPARATOR + NL  # banner: separator, description with short names, separator
+ result += "# " + routine.description + ": " + routine.short_names() + NL
+ result += SEPARATOR + NL
+ result += NL
+
+ result += "cdef extern from \"clblast_c.h\":" + NL  # part 1: extern declarations, one per supported flavour
+ np_dtypes = []  # numpy dtype names of the declared flavours, reused in the check_dtype call below
+ for flavour in routine.flavours:
+ if flavour.precision_name in ["S", "D", "C", "Z"]:  # flavours with other precision names are skipped
+ result += " CLBlastStatusCode CLBlast" + flavour.name + routine.plain_name() + "("
+ result += ", ".join(routine.arguments_def_c(flavour)) + ","
+ result += "cl_command_queue* queue, cl_event* event)" + NL
+ np_dtypes.append(to_np_dtype(flavour))
+ result += "" + NL
+
+ buffers = routine.inputs[:] + routine.outputs[:]  # inputs first, then outputs
+ result += "def " + routine.plain_name() + "(queue, "  # part 2: the wrapper function definition
+ result += ", ".join(routine.arguments_python()) + "):" + NL
+ result += " dtype = check_dtype([" + ", ".join(buffers) + "], "  # check_dtype is not defined in this module
+ result += "[" + ", ".join(['"%s"' % d for d in np_dtypes]) + "])" + NL
+ for buf in buffers:
+ if buf in routine.buffers_vector():
+ result += " check_vector("
+ else:  # every buffer that is not a vector buffer is checked as a matrix
+ result += " check_matrix("
+ result += buf + ", \"" + buf + "\")" + NL
+ result += "" + NL
+
+ for buf in buffers:  # one cl_mem handle per buffer, read from <buf>.base_data.int_ptr
+ result += " cdef cl_mem " + buf + "_buffer = <cl_mem><size_t>" + buf + ".base_data.int_ptr" + NL
+ result += "" + NL
+
+ result += " cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr" + NL
+ result += " cdef cl_event event = NULL" + NL
+
+ for option in routine.options:  # turn the boolean keyword arguments into CLBlast enum values
+ if option == "a_transpose":
+ result += " a_transpose = CLBlastTransposeYes if a_transp else CLBlastTransposeNo" + NL
+ if option == "b_transpose":
+ result += " b_transpose = CLBlastTransposeYes if b_transp else CLBlastTransposeNo" + NL
+ if option == "ab_transpose":
+ result += " ab_transpose = CLBlastTransposeYes if ab_transp else CLBlastTransposeNo" + NL
+ if option == "side":
+ result += " side = CLBlastSideRight if right_side else CLBlastSideLeft" + NL
+ if option == "triangle":
+ result += " triangle = CLBlastTriangleLower if lower_triangle else CLBlastTriangleUpper" + NL
+ if option == "diagonal":
+ result += " diagonal = CLBlastDiagonalUnit if unit_diagonal else CLBlastDiagonalNonUnit" + NL
+
+ result += "" + NL
+ result += " cdef CLBlastStatusCode err" + NL
+ if_prefix = ""  # set to "el" after the first branch so that later branches are emitted as "elif"
+ for flavour in routine.flavours:
+ if flavour.precision_name in ["S", "D", "C", "Z"]:  # same filter as for the extern declarations
+ np_dtype = to_np_dtype(flavour)
+ argument_names = [x.
+ replace("layout", "CLBlastLayoutRowMajor").
+ replace("alpha", scalar_cython_conversion("alpha", flavour)).
+ replace("beta", scalar_cython_conversion("beta", flavour))
+ for x in routine.arguments()]  # substring replacement: layout is fixed, alpha/beta become cast expressions
+ result += " " + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL
+ result += " err = CLBlast" + flavour.name + routine.plain_name()
+ result += "(" + ", ".join(argument_names) + ", &command_queue, &event)" + NL
+ if_prefix = "el"
+
+ result += " else:" + NL  # NOTE(review): with no S/D/C/Z flavour this "else" has no preceding "if" - confirm
+ result += " raise ValueError(\"PyCLBlast: Unrecognized data-type '%s'\" % dtype)" + NL
+ result += " if err != CLBlastSuccess:" + NL
+ result += " raise RuntimeError(\"PyCLBlast: 'CLBlastX" + routine.plain_name() + "' failed: %s\" % get_status_message(err))" + NL  # message always uses the literal prefix 'CLBlastX', not the flavour name
+ result += " return cl.Event.from_int_ptr(<size_t>event)" + NL
+ result += NL
+
+ return result
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index 052709ee..c52f49ca 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -815,6 +815,38 @@ class Routine:
list(chain(*[self.scalar_doc(s) for s in self.other_scalars()])) +
self.batch_count_doc())
+ def arguments_python(self):  # returns a list of parameter strings, some carrying " = <default>"
+ """Arguments for the Python wrapper pyclblast"""
+ result = list()
+ result.extend(self.sizes)  # size arguments come first
+ buffers = self.inputs + self.outputs  # input buffers before output buffers
+ result.extend(buffers[:])
+ for buf in buffers:
+ if buf in self.buffers_matrix():
+ result.append(buf + "_ld")  # "<buf>_ld" has no default value
+ for buf in buffers:
+ if buf in self.buffers_vector():
+ result.append(buf + "_inc = 1")  # "<buf>_inc" defaults to 1
+ for scalar in self.scalars:
+ default = "1.0" if scalar == "alpha" else "0.0"  # alpha defaults to 1.0, every other scalar to 0.0
+ result.append(scalar + " = " + default)
+ for option in self.options:  # each recognised option becomes a boolean keyword defaulting to False
+ if option == "a_transpose":
+ result.append("a_transp = False")
+ if option == "b_transpose":
+ result.append("b_transp = False")
+ if option == "ab_transpose":
+ result.append("ab_transp = False")
+ if option == "side":
+ result.append("right_side = False")
+ if option == "triangle":
+ result.append("lower_triangle = False")
+ if option == "diagonal":
+ result.append("unit_diagonal = False")
+ for buf in buffers:
+ result.append(buf + "_offset = 0")  # offsets come last and default to 0
+ return result
+
def requirements_doc(self):
"""Retrieves a list of routine requirements for documentation"""
return self.requirements