summary refs log tree commit diff
path: root/src/pyclblast
diff options
context:
space:
mode:
author Cedric Nugteren <web@cedricnugteren.nl> 2018-02-18 16:33:20 +0100
committer Cedric Nugteren <web@cedricnugteren.nl> 2018-02-18 16:33:20 +0100
commit e1bfb4082716ef9619a13e9985aca9ef28cf4cbf (patch)
tree 94178e2501add7fd8a3ad683ef3ba2ed1e7cafd8 /src/pyclblast
parent eb85f6b514b285d7dde1ac02b97b7581a46ff21d (diff)
Added GEMM to the Python wrapper
Diffstat (limited to 'src/pyclblast')
-rw-r--r-- src/pyclblast/pyclblast/pyclblast.pyx 40
1 files changed, 40 insertions, 0 deletions
diff --git a/src/pyclblast/pyclblast/pyclblast.pyx b/src/pyclblast/pyclblast/pyclblast.pyx
index a090d367..2f6ebba2 100644
--- a/src/pyclblast/pyclblast/pyclblast.pyx
+++ b/src/pyclblast/pyclblast/pyclblast.pyx
@@ -323,3 +323,43 @@ def swap(queue, n, x, y, x_inc = 1, y_inc = 1, x_offset = 0, y_offset = 0):
return cl.Event.from_int_ptr(<size_t>event)
####################################################################################################
+# General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM
+####################################################################################################
+
+cdef extern from "clblast_c.h":
+    # GEMM prototypes declared from clblast_c.h, one per scalar type used for
+    # alpha/beta: float, double, cl_float2 and cl_double2. Apart from that
+    # scalar type the four parameter lists are identical.
+    CLBlastStatusCode CLBlastSgemm(
+        const CLBlastLayout layout,
+        const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+        const size_t m, const size_t n, const size_t k,
+        const float alpha,
+        const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+        const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+        const float beta,
+        cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+        cl_command_queue* queue, cl_event* event)
+    CLBlastStatusCode CLBlastDgemm(
+        const CLBlastLayout layout,
+        const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+        const size_t m, const size_t n, const size_t k,
+        const double alpha,
+        const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+        const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+        const double beta,
+        cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+        cl_command_queue* queue, cl_event* event)
+    CLBlastStatusCode CLBlastCgemm(
+        const CLBlastLayout layout,
+        const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+        const size_t m, const size_t n, const size_t k,
+        const cl_float2 alpha,
+        const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+        const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+        const cl_float2 beta,
+        cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+        cl_command_queue* queue, cl_event* event)
+    CLBlastStatusCode CLBlastZgemm(
+        const CLBlastLayout layout,
+        const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+        const size_t m, const size_t n, const size_t k,
+        const cl_double2 alpha,
+        const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+        const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+        const cl_double2 beta,
+        cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+        cl_command_queue* queue, cl_event* event)
+
+def gemm(queue, m, n, k, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, a_transp = False, b_transp = False, a_offset = 0, b_offset = 0, c_offset = 0):
+    """
+    xGEMM: General matrix-matrix multiplication
+
+    Validates `a`, `b` and `c` with `check_dtype` / `check_matrix`, then
+    forwards the call to CLBlastSgemm, CLBlastDgemm, CLBlastCgemm or
+    CLBlastZgemm depending on the common dtype (float32, float64, complex64
+    or complex128). The layout passed to CLBlast is always
+    CLBlastLayoutRowMajor.
+
+    Arguments:
+        queue: object whose `int_ptr` is used as the OpenCL command queue.
+        m, n, k: sizes handed to CLBlast unchanged.
+        a, b, c: objects whose `base_data.int_ptr` is used as the OpenCL
+            buffer of the respective matrix.
+        a_ld, b_ld, c_ld: handed to CLBlast as the `*_ld` arguments.
+        alpha, beta: scalars; cast to the real C type for float32/float64,
+            or split into `.real` / `.imag` for the complex dtypes.
+        a_transp, b_transp: truthy selects CLBlastTransposeYes, otherwise
+            CLBlastTransposeNo.
+        a_offset, b_offset, c_offset: handed to CLBlast unchanged.
+
+    Returns:
+        The value of `cl.Event.from_int_ptr` applied to the event filled in
+        by CLBlast.
+
+    Raises:
+        ValueError: the dtype matches none of the four dispatch branches.
+        RuntimeError: CLBlast returned a status other than CLBlastSuccess.
+    """
+    dtype = check_dtype([a, b, c], ["float32", "float64", "complex64", "complex128"])
+    for matrix, label in ((a, "a"), (b, "b"), (c, "c")):
+        check_matrix(matrix, label)
+
+    # Reinterpret the integer handles exposed by the Python objects as the raw
+    # OpenCL types expected by the C API.
+    cdef cl_mem a_buffer = <cl_mem><size_t>a.base_data.int_ptr
+    cdef cl_mem b_buffer = <cl_mem><size_t>b.base_data.int_ptr
+    cdef cl_mem c_buffer = <cl_mem><size_t>c.base_data.int_ptr
+
+    cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr
+    cdef cl_event event = NULL
+
+    # Map the boolean-like flags onto the CLBlast transpose constants.
+    if a_transp:
+        a_transpose = CLBlastTransposeYes
+    else:
+        a_transpose = CLBlastTransposeNo
+    if b_transp:
+        b_transpose = CLBlastTransposeYes
+    else:
+        b_transpose = CLBlastTransposeNo
+
+    # Pick the routine whose scalar type matches the matrices' dtype.
+    cdef CLBlastStatusCode status
+    if dtype == np.dtype("float32"):
+        status = CLBlastSgemm(
+            CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k,
+            <cl_float>alpha,
+            a_buffer, a_offset, a_ld,
+            b_buffer, b_offset, b_ld,
+            <cl_float>beta,
+            c_buffer, c_offset, c_ld,
+            &command_queue, &event)
+    elif dtype == np.dtype("float64"):
+        status = CLBlastDgemm(
+            CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k,
+            <cl_double>alpha,
+            a_buffer, a_offset, a_ld,
+            b_buffer, b_offset, b_ld,
+            <cl_double>beta,
+            c_buffer, c_offset, c_ld,
+            &command_queue, &event)
+    elif dtype == np.dtype("complex64"):
+        status = CLBlastCgemm(
+            CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k,
+            <cl_float2>cl_float2(x=alpha.real,y=alpha.imag),
+            a_buffer, a_offset, a_ld,
+            b_buffer, b_offset, b_ld,
+            <cl_float2>cl_float2(x=beta.real,y=beta.imag),
+            c_buffer, c_offset, c_ld,
+            &command_queue, &event)
+    elif dtype == np.dtype("complex128"):
+        status = CLBlastZgemm(
+            CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k,
+            <cl_double2>cl_double2(x=alpha.real,y=alpha.imag),
+            a_buffer, a_offset, a_ld,
+            b_buffer, b_offset, b_ld,
+            <cl_double2>cl_double2(x=beta.real,y=beta.imag),
+            c_buffer, c_offset, c_ld,
+            &command_queue, &event)
+    else:
+        raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
+    if status != CLBlastSuccess:
+        raise RuntimeError("PyCLBlast: 'CLBlastXgemm' failed: %s" % get_status_message(status))
+    return cl.Event.from_int_ptr(<size_t>event)
+
+####################################################################################################