summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--scripts/generator/generator/convert.py13
-rw-r--r--scripts/generator/generator/pyclblast.py37
-rw-r--r--scripts/generator/generator/routine.py17
-rw-r--r--src/pyclblast/pyclblast/pyclblast.pyx40
4 files changed, 89 insertions, 18 deletions
diff --git a/scripts/generator/generator/convert.py b/scripts/generator/generator/convert.py
index 44eb69d6..07f45669 100644
--- a/scripts/generator/generator/convert.py
+++ b/scripts/generator/generator/convert.py
@@ -80,16 +80,3 @@ def option_to_documentation(x):
'triangle': "The part of the array of the triangular matrix to be used, either `Triangle::kUpper` (121) or `Triangle::kLower` (122).",
'diagonal': "The property of the diagonal matrix, either `Diagonal::kNonUnit` (131) for non-unit values on the diagonal or `Diagonal::kUnit` (132) for unit values on the diagonal.",
}[x]
-
-
-def option_to_clblastdefault(x):
- """Translates an option name to a CLBlast C default type"""
- return {
- 'layout': "CLBlastLayoutColMajor",
- 'a_transpose': "CLBlastTransposeNo",
- 'b_transpose': "CLBlastTransposeNo",
- 'ab_transpose': "CLBlastTransposeNo",
- 'side': "CLBlastSideLeft",
- 'triangle': "CLBlastTriangleUpper",
- 'diagonal': "CLBlastDiagonalNonUnit",
- }[x]
diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py
index 089a410a..ffeaab8d 100644
--- a/scripts/generator/generator/pyclblast.py
+++ b/scripts/generator/generator/pyclblast.py
@@ -21,9 +21,21 @@ def to_np_dtype(flavour):
raise RuntimeError("Could not convert flavour '%s' to numpy" % flavour.precision_name)
+def scalar_cython_conversion(scalar, flavour):
+ if flavour.precision_name == "S":
+ return "<cl_float>" + scalar
+ if flavour.precision_name == "D":
+ return "<cl_double>" + scalar
+ if flavour.precision_name == "C":
+ return "<cl_float2>cl_float2(x=" + scalar + ".real,y=" + scalar + ".imag)"
+ if flavour.precision_name == "Z":
+ return "<cl_double2>cl_double2(x=" + scalar + ".real,y=" + scalar + ".imag)"
+ raise RuntimeError("Could not convert flavour '%s'" % flavour.precision_name)
+
+
def generate_pyx(routine):
result = ""
- if routine.implemented and routine.plain_name() == "swap": # TODO: Generalize
+ if routine.implemented and routine.plain_name() in ["swap", "gemm"]: # TODO: Generalize
result += SEPARATOR + NL
result += "# " + routine.description + ": " + routine.short_names() + NL
@@ -59,16 +71,35 @@ def generate_pyx(routine):
result += " cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr" + NL
result += " cdef cl_event event = NULL" + NL
- result += "" + NL
+ for option in routine.options:
+ if option == "a_transpose":
+ result += " a_transpose = CLBlastTransposeYes if a_transp else CLBlastTransposeNo" + NL
+ if option == "b_transpose":
+ result += " b_transpose = CLBlastTransposeYes if b_transp else CLBlastTransposeNo" + NL
+ if option == "ab_transpose":
+ result += " ab_transpose = CLBlastTransposeYes if ab_transp else CLBlastTransposeNo" + NL
+ if option == "side":
+ result += " side = CLBlastSideRight if right_side else CLBlastSideLeft" + NL
+ if option == "triangle":
+ result += " triangle = CLBlastTriangleLower if lower_triangle else CLBlastTriangleUpper" + NL
+ if option == "diagonal":
+ result += " diagonal = CLBlastDiagonalUnit if unit_diagonal else CLBlastDiagonalNonUnit" + NL
+
+ result += "" + NL
result += " cdef CLBlastStatusCode err" + NL
if_prefix = ""
for flavour in routine.flavours:
if flavour.precision_name in ["S", "D", "C", "Z"]:
np_dtype = to_np_dtype(flavour)
+ argument_names = [x.
+ replace("layout", "CLBlastLayoutRowMajor").
+ replace("alpha", scalar_cython_conversion("alpha", flavour)).
+ replace("beta", scalar_cython_conversion("beta", flavour))
+ for x in routine.arguments()]
result += " " + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL
result += " err = CLBlast" + flavour.name + routine.plain_name()
- result += "(" + ", ".join(routine.arguments()) + ", &command_queue, &event)" + NL
+ result += "(" + ", ".join(argument_names) + ", &command_queue, &event)" + NL
if_prefix = "el"
result += " else:" + NL
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index d0b0a6d7..c52f49ca 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -827,9 +827,22 @@ class Routine:
for buf in buffers:
if buf in self.buffers_vector():
result.append(buf + "_inc = 1")
+ for scalar in self.scalars:
+ default = "1.0" if scalar == "alpha" else "0.0"
+ result.append(scalar + " = " + default)
for option in self.options:
- default = convert.option_to_clblastdefault(option)
- result.append(option + " = " + default)
+ if option == "a_transpose":
+ result.append("a_transp = False")
+ if option == "b_transpose":
+ result.append("b_transp = False")
+ if option == "ab_transpose":
+ result.append("ab_transp = False")
+ if option == "side":
+ result.append("right_side = False")
+ if option == "triangle":
+ result.append("lower_triangle = False")
+ if option == "diagonal":
+ result.append("unit_diagonal = False")
for buf in buffers:
result.append(buf + "_offset = 0")
return result
diff --git a/src/pyclblast/pyclblast/pyclblast.pyx b/src/pyclblast/pyclblast/pyclblast.pyx
index a090d367..2f6ebba2 100644
--- a/src/pyclblast/pyclblast/pyclblast.pyx
+++ b/src/pyclblast/pyclblast/pyclblast.pyx
@@ -323,3 +323,43 @@ def swap(queue, n, x, y, x_inc = 1, y_inc = 1, x_offset = 0, y_offset = 0):
return cl.Event.from_int_ptr(<size_t>event)
####################################################################################################
+# General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM
+####################################################################################################
+
+cdef extern from "clblast_c.h":
+ CLBlastStatusCode CLBlastSgemm(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const float alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const float beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld,cl_command_queue* queue, cl_event* event)
+ CLBlastStatusCode CLBlastDgemm(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const double alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const double beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld,cl_command_queue* queue, cl_event* event)
+ CLBlastStatusCode CLBlastCgemm(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const cl_float2 alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const cl_float2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld,cl_command_queue* queue, cl_event* event)
+ CLBlastStatusCode CLBlastZgemm(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const cl_double2 alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld,cl_command_queue* queue, cl_event* event)
+
+def gemm(queue, m, n, k, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, a_transp = False, b_transp = False, a_offset = 0, b_offset = 0, c_offset = 0):
+ dtype = check_dtype([a, b, c], ["float32", "float64", "complex64", "complex128"])
+ check_matrix(a, "a")
+ check_matrix(b, "b")
+ check_matrix(c, "c")
+
+ cdef cl_mem a_buffer = <cl_mem><size_t>a.base_data.int_ptr
+ cdef cl_mem b_buffer = <cl_mem><size_t>b.base_data.int_ptr
+ cdef cl_mem c_buffer = <cl_mem><size_t>c.base_data.int_ptr
+
+ cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr
+ cdef cl_event event = NULL
+ a_transpose = CLBlastTransposeYes if a_transp else CLBlastTransposeNo
+ b_transpose = CLBlastTransposeYes if b_transp else CLBlastTransposeNo
+
+ cdef CLBlastStatusCode err
+ if dtype == np.dtype("float32"):
+ err = CLBlastSgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_float>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_float>beta, c_buffer, c_offset, c_ld, &command_queue, &event)
+ elif dtype == np.dtype("float64"):
+ err = CLBlastDgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_double>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_double>beta, c_buffer, c_offset, c_ld, &command_queue, &event)
+ elif dtype == np.dtype("complex64"):
+ err = CLBlastCgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_float2>cl_float2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_float2>cl_float2(x=beta.real,y=beta.imag), c_buffer, c_offset, c_ld, &command_queue, &event)
+ elif dtype == np.dtype("complex128"):
+ err = CLBlastZgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_double2>cl_double2(x=beta.real,y=beta.imag), c_buffer, c_offset, c_ld, &command_queue, &event)
+ else:
+ raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+ if err != CLBlastSuccess:
+ raise RuntimeError("PyCLBlast: 'CLBlastXgemm' failed: %s" % get_status_message(err))
+ return cl.Event.from_int_ptr(<size_t>event)
+
+####################################################################################################