diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-02-18 16:33:20 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2018-02-18 16:33:20 +0100 |
commit | e1bfb4082716ef9619a13e9985aca9ef28cf4cbf (patch) | |
tree | 94178e2501add7fd8a3ad683ef3ba2ed1e7cafd8 /src/pyclblast | |
parent | eb85f6b514b285d7dde1ac02b97b7581a46ff21d (diff) |
Added GEMM to the Python wrapper
Diffstat (limited to 'src/pyclblast')
-rw-r--r-- | src/pyclblast/pyclblast/pyclblast.pyx | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/src/pyclblast/pyclblast/pyclblast.pyx b/src/pyclblast/pyclblast/pyclblast.pyx index a090d367..2f6ebba2 100644 --- a/src/pyclblast/pyclblast/pyclblast.pyx +++ b/src/pyclblast/pyclblast/pyclblast.pyx @@ -323,3 +323,43 @@ def swap(queue, n, x, y, x_inc = 1, y_inc = 1, x_offset = 0, y_offset = 0): return cl.Event.from_int_ptr(<size_t>event) #################################################################################################### +# General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM +#################################################################################################### + +cdef extern from "clblast_c.h": + CLBlastStatusCode CLBlastSgemm(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const float alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const float beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld,cl_command_queue* queue, cl_event* event) + CLBlastStatusCode CLBlastDgemm(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const double alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const double beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld,cl_command_queue* queue, cl_event* event) + CLBlastStatusCode CLBlastCgemm(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const cl_float2 alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const cl_float2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld,cl_command_queue* queue, cl_event* event) + CLBlastStatusCode CLBlastZgemm(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const cl_double2 alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld,cl_command_queue* queue, cl_event* event) + +def gemm(queue, m, n, k, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, a_transp = False, b_transp = False, a_offset = 0, b_offset = 0, c_offset = 0): + dtype = check_dtype([a, b, c], ["float32", "float64", "complex64", "complex128"]) + check_matrix(a, "a") + check_matrix(b, "b") + check_matrix(c, "c") + + cdef cl_mem a_buffer = <cl_mem><size_t>a.base_data.int_ptr + cdef cl_mem b_buffer = <cl_mem><size_t>b.base_data.int_ptr + cdef cl_mem c_buffer = <cl_mem><size_t>c.base_data.int_ptr + + cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr + cdef cl_event event = NULL + a_transpose = CLBlastTransposeYes if a_transp else CLBlastTransposeNo + b_transpose = CLBlastTransposeYes if b_transp else CLBlastTransposeNo + + cdef CLBlastStatusCode err + if dtype == np.dtype("float32"): + err = CLBlastSgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_float>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_float>beta, c_buffer, c_offset, c_ld, &command_queue, &event) + elif dtype == np.dtype("float64"): + err = CLBlastDgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_double>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_double>beta, c_buffer, c_offset, c_ld, &command_queue, &event) + elif dtype == np.dtype("complex64"): + err = CLBlastCgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_float2>cl_float2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_float2>cl_float2(x=beta.real,y=beta.imag), c_buffer, c_offset, c_ld, &command_queue, &event) + elif dtype == np.dtype("complex128"): + err = CLBlastZgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_double2>cl_double2(x=beta.real,y=beta.imag), c_buffer, c_offset, c_ld, &command_queue, &event) + else: + raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: + raise RuntimeError("PyCLBlast: 'CLBlastXgemm' failed: %s" % get_status_message(err)) + return cl.Event.from_int_ptr(<size_t>event) + +#################################################################################################### |