summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Cedric Nugteren <web@cedricnugteren.nl> 2018-02-14 20:50:47 +0100
committer: Cedric Nugteren <web@cedricnugteren.nl> 2018-02-14 20:50:47 +0100
commit: eb85f6b514b285d7dde1ac02b97b7581a46ff21d (patch)
tree: d32b88f353162d53fff1d82a82890c4080a9565a
parent: 61b8c771ed906720459b029d91f97c7df0785938 (diff)
First generated version (clblastXswap only for now) of the pyclblast wrapper
-rwxr-xr-x scripts/generator/generator.py 12
-rw-r--r-- scripts/generator/generator/convert.py 13
-rw-r--r-- scripts/generator/generator/pyclblast.py 81
-rw-r--r-- scripts/generator/generator/routine.py 19
-rw-r--r-- src/pyclblast/pyclblast/pyclblast.pyx 78
5 files changed, 136 insertions, 67 deletions
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 955625f5..c25d0e4f 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -18,6 +18,7 @@
# clblast_netlib_c.cpp
# wrapper_clblas.h
# wrapper_cblas.h
+# pyclblast.pyx
# It also generates the main functions for the correctness and performance tests as found in
# test/correctness/routines/levelX/xYYYY.cpp
# test/performance/routines/levelX/xYYYY.cpp
@@ -30,6 +31,7 @@ import argparse
import generator.cpp as cpp
import generator.doc as doc
+import generator.pyclblast as pyclblast
from generator.routine import Routine
from generator.datatype import H, S, D, C, Z, Sc, Dz, iH, iS, iD, iC, iZ, Css, Zdd, Ccs, Zzd, T, Tc, TU
@@ -45,9 +47,10 @@ FILES = [
"/src/clblast_netlib_c.cpp",
"/include/clblast_cuda.h",
"/src/clblast_cuda.cpp",
+ "/src/pyclblast/pyclblast/pyclblast.pyx"
]
-HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 95, 21]
-FOOTER_LINES = [41, 56, 27, 38, 6, 6, 6, 9, 2, 41, 55]
+HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 95, 21, 288]
+FOOTER_LINES = [41, 56, 27, 38, 6, 6, 6, 9, 2, 41, 55, 1]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 63
@@ -209,7 +212,8 @@ def main(argv):
body = ""
levels = [1, 2, 3] if (i == 4 or i == 5 or i == 6) else [1, 2, 3, 4]
for level in levels:
- body += cpp.LEVEL_SEPARATORS[level - 1] + "\n"
+ if i not in [11]:
+ body += cpp.LEVEL_SEPARATORS[level - 1] + "\n"
for routine in ROUTINES[level - 1]:
if i == 0:
body += cpp.clblast_h(routine)
@@ -235,6 +239,8 @@ def main(argv):
body += cpp.clblast_h(routine, cuda=True)
if i == 10:
body += cpp.clblast_cc(routine, cuda=True)
+ if i == 11:
+ body += pyclblast.generate_pyx(routine)
f.write("".join(file_header))
f.write(body)
f.write("".join(file_footer))
diff --git a/scripts/generator/generator/convert.py b/scripts/generator/generator/convert.py
index 07f45669..44eb69d6 100644
--- a/scripts/generator/generator/convert.py
+++ b/scripts/generator/generator/convert.py
@@ -80,3 +80,16 @@ def option_to_documentation(x):
'triangle': "The part of the array of the triangular matrix to be used, either `Triangle::kUpper` (121) or `Triangle::kLower` (122).",
'diagonal': "The property of the diagonal matrix, either `Diagonal::kNonUnit` (131) for non-unit values on the diagonal or `Diagonal::kUnit` (132) for unit values on the diagonal.",
}[x]
+
+
+def option_to_clblastdefault(x):
+ """Translates an option name to a CLBlast C default type"""
+ return {
+ 'layout': "CLBlastLayoutColMajor",
+ 'a_transpose': "CLBlastTransposeNo",
+ 'b_transpose': "CLBlastTransposeNo",
+ 'ab_transpose': "CLBlastTransposeNo",
+ 'side': "CLBlastSideLeft",
+ 'triangle': "CLBlastTriangleUpper",
+ 'diagonal': "CLBlastDiagonalNonUnit",
+ }[x]
diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py
new file mode 100644
index 00000000..089a410a
--- /dev/null
+++ b/scripts/generator/generator/pyclblast.py
@@ -0,0 +1,81 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+NL = "\n"
+SEPARATOR = "####################################################################################################"
+
+
+def to_np_dtype(flavour):
+ if flavour.precision_name == "S":
+ return "float32"
+ if flavour.precision_name == "D":
+ return "float64"
+ if flavour.precision_name == "C":
+ return "complex64"
+ if flavour.precision_name == "Z":
+ return "complex128"
+ raise RuntimeError("Could not convert flavour '%s' to numpy" % flavour.precision_name)
+
+
+def generate_pyx(routine):
+ result = ""
+ if routine.implemented and routine.plain_name() == "swap": # TODO: Generalize
+
+ result += SEPARATOR + NL
+ result += "# " + routine.description + ": " + routine.short_names() + NL
+ result += SEPARATOR + NL
+ result += NL
+
+ result += "cdef extern from \"clblast_c.h\":" + NL
+ np_dtypes = []
+ for flavour in routine.flavours:
+ if flavour.precision_name in ["S", "D", "C", "Z"]:
+ result += " CLBlastStatusCode CLBlast" + flavour.name + routine.plain_name() + "("
+ result += ", ".join(routine.arguments_def_c(flavour)) + ","
+ result += "cl_command_queue* queue, cl_event* event)" + NL
+ np_dtypes.append(to_np_dtype(flavour))
+ result += "" + NL
+
+ buffers = routine.inputs[:] + routine.outputs[:]
+ result += "def " + routine.plain_name() + "(queue, "
+ result += ", ".join(routine.arguments_python()) + "):" + NL
+ result += " dtype = check_dtype([" + ", ".join(buffers) + "], "
+ result += "[" + ", ".join(['"%s"' % d for d in np_dtypes]) + "])" + NL
+ for buf in buffers:
+ if buf in routine.buffers_vector():
+ result += " check_vector("
+ else:
+ result += " check_matrix("
+ result += buf + ", \"" + buf + "\")" + NL
+ result += "" + NL
+
+ for buf in buffers:
+ result += " cdef cl_mem " + buf + "_buffer = <cl_mem><size_t>" + buf + ".base_data.int_ptr" + NL
+ result += "" + NL
+
+ result += " cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr" + NL
+ result += " cdef cl_event event = NULL" + NL
+ result += "" + NL
+
+ result += " cdef CLBlastStatusCode err" + NL
+ if_prefix = ""
+ for flavour in routine.flavours:
+ if flavour.precision_name in ["S", "D", "C", "Z"]:
+ np_dtype = to_np_dtype(flavour)
+ result += " " + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL
+ result += " err = CLBlast" + flavour.name + routine.plain_name()
+ result += "(" + ", ".join(routine.arguments()) + ", &command_queue, &event)" + NL
+ if_prefix = "el"
+
+ result += " else:" + NL
+ result += " raise ValueError(\"PyCLBlast: Unrecognized data-type '%s'\" % dtype)" + NL
+ result += " if err != CLBlastSuccess:" + NL
+ result += " raise RuntimeError(\"PyCLBlast: 'CLBlastX" + routine.plain_name() + "' failed: %s\" % get_status_message(err))" + NL
+ result += " return cl.Event.from_int_ptr(<size_t>event)" + NL
+ result += NL
+
+ return result
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index 052709ee..d0b0a6d7 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -815,6 +815,25 @@ class Routine:
list(chain(*[self.scalar_doc(s) for s in self.other_scalars()])) +
self.batch_count_doc())
+ def arguments_python(self):
+ """Arguments for the Python wrapper pyclblast"""
+ result = list()
+ result.extend(self.sizes)
+ buffers = self.inputs + self.outputs
+ result.extend(buffers[:])
+ for buf in buffers:
+ if buf in self.buffers_matrix():
+ result.append(buf + "_ld")
+ for buf in buffers:
+ if buf in self.buffers_vector():
+ result.append(buf + "_inc = 1")
+ for option in self.options:
+ default = convert.option_to_clblastdefault(option)
+ result.append(option + " = " + default)
+ for buf in buffers:
+ result.append(buf + "_offset = 0")
+ return result
+
def requirements_doc(self):
"""Retrieves a list of routine requirements for documentation"""
return self.requirements
diff --git a/src/pyclblast/pyclblast/pyclblast.pyx b/src/pyclblast/pyclblast/pyclblast.pyx
index 0cc3b237..a090d367 100644
--- a/src/pyclblast/pyclblast/pyclblast.pyx
+++ b/src/pyclblast/pyclblast/pyclblast.pyx
@@ -1,7 +1,6 @@
####################################################################################################
# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0.
-# This file follows uses a max-width of 100 characters per line.
#
# Author(s):
# Cedric Nugteren <www.cedricnugteren.nl>
@@ -287,87 +286,38 @@ def check_vector(a, name):
check_array(a, 1, name)
-def check_shape_dim(shape, dim, target, name):
- if shape[dim] != target:
- raise ValueError("PyCLBlast: '%s.shape[%d]' must be %d (got %d)" % (name, dim, target, shape[dim]))
-
####################################################################################################
# Swap two vectors: SSWAP/DSWAP/CSWAP/ZSWAP/HSWAP
####################################################################################################
cdef extern from "clblast_c.h":
- CLBlastStatusCode CLBlastSswap(
- const size_t n,
- cl_mem x_buffer,
- const size_t x_offset,
- const size_t x_inc,
- cl_mem y_buffer,
- const size_t y_offset,
- const size_t y_inc,
- cl_command_queue* queue,
- cl_event* event)
- CLBlastStatusCode CLBlastDswap(
- const size_t n,
- cl_mem x_buffer,
- const size_t x_offset,
- const size_t x_inc,
- cl_mem y_buffer,
- const size_t y_offset,
- const size_t y_inc,
- cl_command_queue* queue,
- cl_event* event)
- CLBlastStatusCode CLBlastCswap(
- const size_t n,
- cl_mem x_buffer,
- const size_t x_offset,
- const size_t x_inc,
- cl_mem y_buffer,
- const size_t y_offset,
- const size_t y_inc,
- cl_command_queue* queue,
- cl_event* event)
- CLBlastStatusCode CLBlastZswap(
- const size_t n,
- cl_mem x_buffer,
- const size_t x_offset,
- const size_t x_inc,
- cl_mem y_buffer,
- const size_t y_offset,
- const size_t y_inc,
- cl_command_queue* queue,
- cl_event* event)
+ CLBlastStatusCode CLBlastSswap(const size_t n, cl_mem x_buffer, const size_t x_offset, const size_t x_inc, cl_mem y_buffer, const size_t y_offset, const size_t y_inc,cl_command_queue* queue, cl_event* event)
+ CLBlastStatusCode CLBlastDswap(const size_t n, cl_mem x_buffer, const size_t x_offset, const size_t x_inc, cl_mem y_buffer, const size_t y_offset, const size_t y_inc,cl_command_queue* queue, cl_event* event)
+ CLBlastStatusCode CLBlastCswap(const size_t n, cl_mem x_buffer, const size_t x_offset, const size_t x_inc, cl_mem y_buffer, const size_t y_offset, const size_t y_inc,cl_command_queue* queue, cl_event* event)
+ CLBlastStatusCode CLBlastZswap(const size_t n, cl_mem x_buffer, const size_t x_offset, const size_t x_inc, cl_mem y_buffer, const size_t y_offset, const size_t y_inc,cl_command_queue* queue, cl_event* event)
-def swap(queue, x, y):
- """y, x = x, y"""
+def swap(queue, n, x, y, x_inc = 1, y_inc = 1, x_offset = 0, y_offset = 0):
dtype = check_dtype([x, y], ["float32", "float64", "complex64", "complex128"])
check_vector(x, "x")
check_vector(y, "y")
- cdef size_t N = x.shape[0]
- check_shape_dim(y.shape, 0, N, "y")
-
- cdef size_t element_size = dtype_size[dtype]
- cdef cl_mem xdata = <cl_mem><size_t>x.base_data.int_ptr
- cdef size_t offx = x.offset / element_size
- cdef int incx = x.strides[0] / element_size
- cdef cl_mem ydata = <cl_mem><size_t>y.base_data.int_ptr
- cdef size_t offy = y.offset / element_size
- cdef int incy = y.strides[0] / element_size
+ cdef cl_mem x_buffer = <cl_mem><size_t>x.base_data.int_ptr
+ cdef cl_mem y_buffer = <cl_mem><size_t>y.base_data.int_ptr
- cdef cl_command_queue commandQueue = <cl_command_queue><size_t>queue.int_ptr
+ cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr
cdef cl_event event = NULL
- cdef CLBlastStatusCode
+ cdef CLBlastStatusCode err
if dtype == np.dtype("float32"):
- err = CLBlastSswap(N, xdata, offx, incx, ydata, offy, incy, &commandQueue, &event)
+ err = CLBlastSswap(n, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event)
elif dtype == np.dtype("float64"):
- err = CLBlastDswap(N, xdata, offx, incx, ydata, offy, incy, &commandQueue, &event)
+ err = CLBlastDswap(n, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event)
elif dtype == np.dtype("complex64"):
- err = CLBlastCswap(N, xdata, offx, incx, ydata, offy, incy, &commandQueue, &event)
+ err = CLBlastCswap(n, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event)
elif dtype == np.dtype("complex128"):
- err = CLBlastZswap(N, xdata, offx, incx, ydata, offy, incy, &commandQueue, &event)
+ err = CLBlastZswap(n, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event)
else:
- raise ValueError("PyCLBlast: Unrecognized dtype '%s'" % dtype)
+ raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXswap' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)