diff options
-rw-r--r-- | scripts/generator/generator/convert.py | 13 | ||||
-rw-r--r-- | scripts/generator/generator/pyclblast.py | 37 | ||||
-rw-r--r-- | scripts/generator/generator/routine.py | 17 | ||||
-rw-r--r-- | src/pyclblast/pyclblast/pyclblast.pyx | 40 |
4 files changed, 89 insertions, 18 deletions
diff --git a/scripts/generator/generator/convert.py b/scripts/generator/generator/convert.py index 44eb69d6..07f45669 100644 --- a/scripts/generator/generator/convert.py +++ b/scripts/generator/generator/convert.py @@ -80,16 +80,3 @@ def option_to_documentation(x): 'triangle': "The part of the array of the triangular matrix to be used, either `Triangle::kUpper` (121) or `Triangle::kLower` (122).", 'diagonal': "The property of the diagonal matrix, either `Diagonal::kNonUnit` (131) for non-unit values on the diagonal or `Diagonal::kUnit` (132) for unit values on the diagonal.", }[x] - - -def option_to_clblastdefault(x): - """Translates an option name to a CLBlast C default type""" - return { - 'layout': "CLBlastLayoutColMajor", - 'a_transpose': "CLBlastTransposeNo", - 'b_transpose': "CLBlastTransposeNo", - 'ab_transpose': "CLBlastTransposeNo", - 'side': "CLBlastSideLeft", - 'triangle': "CLBlastTriangleUpper", - 'diagonal': "CLBlastDiagonalNonUnit", - }[x] diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py index 089a410a..ffeaab8d 100644 --- a/scripts/generator/generator/pyclblast.py +++ b/scripts/generator/generator/pyclblast.py @@ -21,9 +21,21 @@ def to_np_dtype(flavour): raise RuntimeError("Could not convert flavour '%s' to numpy" % flavour.precision_name) +def scalar_cython_conversion(scalar, flavour): + if flavour.precision_name == "S": + return "<cl_float>" + scalar + if flavour.precision_name == "D": + return "<cl_double>" + scalar + if flavour.precision_name == "C": + return "<cl_float2>cl_float2(x=" + scalar + ".real,y=" + scalar + ".imag)" + if flavour.precision_name == "Z": + return "<cl_double2>cl_double2(x=" + scalar + ".real,y=" + scalar + ".imag)" + raise RuntimeError("Could not convert flavour '%s'" % flavour.precision_name) + + def generate_pyx(routine): result = "" - if routine.implemented and routine.plain_name() == "swap": # TODO: Generalize + if routine.implemented and routine.plain_name() in ["swap", "gemm"]: # TODO: Generalize result += SEPARATOR + NL result += "# " + routine.description + ": " + routine.short_names() + NL @@ -59,16 +71,35 @@ def generate_pyx(routine): result += " cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr" + NL result += " cdef cl_event event = NULL" + NL - result += "" + NL + for option in routine.options: + if option == "a_transpose": + result += " a_transpose = CLBlastTransposeYes if a_transp else CLBlastTransposeNo" + NL + if option == "b_transpose": + result += " b_transpose = CLBlastTransposeYes if b_transp else CLBlastTransposeNo" + NL + if option == "ab_transpose": + result += " ab_transpose = CLBlastTransposeYes if ab_transp else CLBlastTransposeNo" + NL + if option == "side": + result += " side = CLBlastSideRight if right_side else CLBlastSideLeft" + NL + if option == "triangle": + result += " triangle = CLBlastTriangleLower if lower_triangle else CLBlastTriangleUpper" + NL + if option == "diagonal": + result += " diagonal = CLBlastDiagonalUnit if unit_diagonal else CLBlastDiagonalNonUnit" + NL + + result += "" + NL result += " cdef CLBlastStatusCode err" + NL if_prefix = "" for flavour in routine.flavours: if flavour.precision_name in ["S", "D", "C", "Z"]: np_dtype = to_np_dtype(flavour) + argument_names = [x. + replace("layout", "CLBlastLayoutRowMajor"). + replace("alpha", scalar_cython_conversion("alpha", flavour)). + replace("beta", scalar_cython_conversion("beta", flavour)) + for x in routine.arguments()] result += " " + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL result += " err = CLBlast" + flavour.name + routine.plain_name() - result += "(" + ", ".join(routine.arguments()) + ", &command_queue, &event)" + NL + result += "(" + ", ".join(argument_names) + ", &command_queue, &event)" + NL if_prefix = "el" result += " else:" + NL diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py index d0b0a6d7..c52f49ca 100644 --- a/scripts/generator/generator/routine.py +++ b/scripts/generator/generator/routine.py @@ -827,9 +827,22 @@ class Routine: for buf in buffers: if buf in self.buffers_vector(): result.append(buf + "_inc = 1") + for scalar in self.scalars: + default = "1.0" if scalar == "alpha" else "0.0" + result.append(scalar + " = " + default) for option in self.options: - default = convert.option_to_clblastdefault(option) - result.append(option + " = " + default) + if option == "a_transpose": + result.append("a_transp = False") + if option == "b_transpose": + result.append("b_transp = False") + if option == "ab_transpose": + result.append("ab_transp = False") + if option == "side": + result.append("right_side = False") + if option == "triangle": + result.append("lower_triangle = False") + if option == "diagonal": + result.append("unit_diagonal = False") for buf in buffers: result.append(buf + "_offset = 0") return result diff --git a/src/pyclblast/pyclblast/pyclblast.pyx b/src/pyclblast/pyclblast/pyclblast.pyx index a090d367..2f6ebba2 100644 --- a/src/pyclblast/pyclblast/pyclblast.pyx +++ b/src/pyclblast/pyclblast/pyclblast.pyx @@ -323,3 +323,43 @@ def swap(queue, n, x, y, x_inc = 1, y_inc = 1, x_offset = 0, y_offset = 0): return cl.Event.from_int_ptr(<size_t>event) #################################################################################################### +# General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM +#################################################################################################### + +cdef extern from "clblast_c.h": + CLBlastStatusCode CLBlastSgemm(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const float alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const float beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld,cl_command_queue* queue, cl_event* event) + CLBlastStatusCode CLBlastDgemm(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const double alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const double beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld,cl_command_queue* queue, cl_event* event) + CLBlastStatusCode CLBlastCgemm(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const cl_float2 alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const cl_float2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld,cl_command_queue* queue, cl_event* event) + CLBlastStatusCode CLBlastZgemm(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const cl_double2 alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld,cl_command_queue* queue, cl_event* event) + +def gemm(queue, m, n, k, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, a_transp = False, b_transp = False, a_offset = 0, b_offset = 0, c_offset = 0): + dtype = check_dtype([a, b, c], ["float32", "float64", "complex64", "complex128"]) + check_matrix(a, "a") + check_matrix(b, "b") + check_matrix(c, "c") + + cdef cl_mem a_buffer = <cl_mem><size_t>a.base_data.int_ptr + cdef cl_mem b_buffer = <cl_mem><size_t>b.base_data.int_ptr + cdef cl_mem c_buffer = <cl_mem><size_t>c.base_data.int_ptr + + cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr + cdef cl_event event = NULL + a_transpose = CLBlastTransposeYes if a_transp else CLBlastTransposeNo + b_transpose = CLBlastTransposeYes if b_transp else CLBlastTransposeNo + + cdef CLBlastStatusCode err + if dtype == np.dtype("float32"): + err = CLBlastSgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_float>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_float>beta, c_buffer, c_offset, c_ld, &command_queue, &event) + elif dtype == np.dtype("float64"): + err = CLBlastDgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_double>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_double>beta, c_buffer, c_offset, c_ld, &command_queue, &event) + elif dtype == np.dtype("complex64"): + err = CLBlastCgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_float2>cl_float2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_float2>cl_float2(x=beta.real,y=beta.imag), c_buffer, c_offset, c_ld, &command_queue, &event) + elif dtype == np.dtype("complex128"): + err = CLBlastZgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_double2>cl_double2(x=beta.real,y=beta.imag), c_buffer, c_offset, c_ld, &command_queue, &event) + else: + raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: + raise RuntimeError("PyCLBlast: 'CLBlastXgemm' failed: %s" % get_status_message(err)) + return cl.Event.from_int_ptr(<size_t>event) + +#################################################################################################### |