diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2020-05-10 18:23:41 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-10 18:23:41 +0200 |
commit | 9abc4167854f77e7982957232b0c23483a79c97a (patch) | |
tree | 07875a330ec015ef3dff37e5bf1090917440e351 | |
parent | 5f4b3ffcf7c8f90eee33a8504ede00ea52f79c0e (diff) | |
parent | 0870e76fbacaf88164f271b1ede79b8298214aff (diff) |
Merge pull request #386 from CNugteren/CLBlast-384-pyclblast-missing-routines
PyCLBlast: add missing batched routines
-rw-r--r-- | CHANGELOG | 1 | ||||
-rw-r--r-- | scripts/generator/generator/pyclblast.py | 70 | ||||
-rw-r--r-- | scripts/generator/generator/routine.py | 43 | ||||
-rw-r--r-- | src/pyclblast/samples/saxpybatched.py | 46 | ||||
-rw-r--r-- | src/pyclblast/setup.py | 2 | ||||
-rw-r--r-- | src/pyclblast/src/pyclblast.pyx | 257 |
6 files changed, 401 insertions, 18 deletions
@@ -1,5 +1,6 @@ Development version (next version) - Changed XAMAX/XAMIN to more likely return first rather than last min/max index, updated API docs +- Added batched routines to pyclblast - Various minor fixes and enhancements Version 1.5.1 diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py index 47eb2eb4..b7ec348e 100644 --- a/scripts/generator/generator/pyclblast.py +++ b/scripts/generator/generator/pyclblast.py @@ -22,6 +22,16 @@ def to_np_dtype(flavour): }[flavour.precision_name] +def cl_type(flavour): + return { + "S": "cl_float", + "D": "cl_double", + "C": "cl_float2", + "Z": "cl_double2", + "H": "cl_half", + }[flavour.precision_name] + + def scalar_cython_conversion(scalar, flavour): scalar_type = flavour.alpha_cl if scalar == "alpha" else flavour.beta_cl if scalar_type == "float": @@ -39,7 +49,9 @@ def scalar_cython_conversion(scalar, flavour): def generate_pyx(routine): result = "" - if routine.implemented and routine.plain_name() and routine.level in ["1", "2a", "2b", "3"]: + if routine.implemented and routine.plain_name() and routine.level in ["1", "2a", "2b", "3", "x"]: + if routine.level == "x" and routine.batched == 0: + return result # level-X routines that are non-batched are not supported at the moment indent = " " result += SEPARATOR + NL @@ -80,6 +92,33 @@ def generate_pyx(routine): result += buf + ", \"" + buf + "\")" + NL result += NL + # Batched checks + if routine.batched == 1: # batched but not strided-batched + lists = [b + "_offsets" for b in buffers] + [s + "s" for s in routine.scalars] + result += indent + "if " + " != ".join(["len(" + l + ")" for l in lists]) + ":" + NL + result += indent + indent + "raise RuntimeError(\"PyCLBlast: 'CLBlastX" + routine.plain_name() + "' failed: length of batch-sized arguments " + ", ".join(lists) + " should be equal\")" + NL + result += indent + "batch_count = len(" + lists[0] + ")" + NL + result += NL + + # Batched list to pointer conversions + for buf in 
buffers: + result += indent + "cdef size_t *" + buf + "_offsets_c = <size_t *> PyMem_Malloc(batch_count * sizeof(size_t))" + NL + result += indent + "for i in range(batch_count):" + NL + result += indent + indent + "" + buf + "_offsets_c[i] = " + buf + "_offsets[i]" + NL + for scalar in routine.scalars: + result += indent + "cdef void *" + scalar + "s_c = <void *> PyMem_Malloc(batch_count * sizeof(dtype_size[dtype]))" + NL + result += indent + "for i in range(batch_count):" + NL + if_prefix = "" + for flavour in routine.flavours: + if flavour.precision_name in ["S", "D", "C", "Z", "H"]: + np_dtype = to_np_dtype(flavour) + result += indent + indent + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL + scalar_converted = scalar_cython_conversion(scalar + "s[i]", flavour) + result += indent + indent + indent + "(<" + cl_type(flavour) + "*>" + scalar + "s_c)[i] = " + scalar_converted + NL + if_prefix = "el" + + result += NL + # Buffer transformation for buf in buffers: result += indent + "cdef cl_mem " + buf + "_buffer = <cl_mem><size_t>" + buf + ".base_data.int_ptr" + NL @@ -108,11 +147,22 @@ def generate_pyx(routine): for flavour in routine.flavours: if flavour.precision_name in ["S", "D", "C", "Z", "H"]: np_dtype = to_np_dtype(flavour) - argument_names = [x. - replace("layout", "CLBlastLayoutRowMajor"). - replace("alpha", scalar_cython_conversion("alpha", flavour)). - replace("beta", scalar_cython_conversion("beta", flavour)) - for x in routine.arguments()] + if routine.batched != 1: # regular or strided-batched + argument_names = [x. + replace("layout", "CLBlastLayoutRowMajor"). + replace("alpha", scalar_cython_conversion("alpha", flavour)). + replace("beta", scalar_cython_conversion("beta", flavour)) + for x in routine.arguments()] + else: # batched but not strided-batched + argument_names = [x. + replace("layout", "CLBlastLayoutRowMajor"). + replace("_cpp", "_c"). + replace("_offsets", "_offsets_c"). 
+ replace("alphas_c", "<" + cl_type(flavour) + "*>alphas_c"). + replace("betas_c", "<" + cl_type(flavour) + "*>betas_c") + for x in routine.arguments()] + if routine.batched > 0: + argument_names.append("batch_count") result += indent + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL result += indent + indent + "err = CLBlast" + flavour.name + routine.plain_name() result += "(" + ", ".join(argument_names) + ", &command_queue, &event)" + NL @@ -120,6 +170,14 @@ def generate_pyx(routine): result += indent + "else:" + NL result += indent + indent + "raise ValueError(\"PyCLBlast: Unrecognized data-type '%s'\" % dtype)" + NL + result += NL + + # Cleaning up + if routine.batched == 1: # batched but not strided-batched + for array in [b + "_offset" for b in buffers] + routine.scalars: + result += indent + "PyMem_Free(" + array + "s_c)" + NL + result += NL + result += indent + "if err != CLBlastSuccess:" + NL result += indent + indent + "raise RuntimeError(\"PyCLBlast: 'CLBlastX" + routine.plain_name() + "' failed: %s\" % get_status_message(err))" + NL result += indent + "return cl.Event.from_int_ptr(<size_t>event)" + NL diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py index 3b5a6b76..8b6ab57f 100644 --- a/scripts/generator/generator/routine.py +++ b/scripts/generator/generator/routine.py @@ -825,17 +825,37 @@ class Routine: """Arguments for the Python wrapper pyclblast""" result = list() result.extend(self.sizes) + if self.batched == 2: # strided batched + result.append("batch_count") buffers = self.inputs + self.outputs result.extend(buffers[:]) - for buf in buffers: - if buf in self.buffers_matrix(): - result.append(buf + "_ld") - for buf in buffers: - if buf in self.buffers_vector(): - result.append(buf + "_inc = 1") - for scalar in self.scalars: - default = "1.0" if scalar == "alpha" else "0.0" - result.append(scalar + " = " + default) + if self.batched != 1: # regular or strided-batched + for buf in buffers: 
+ if buf in self.buffers_matrix(): + result.append(buf + "_ld") + for buf in buffers: + if buf in self.buffers_vector(): + result.append(buf + "_inc = 1") + if self.batched == 2: # strided batched + for buf in buffers: + if buf in self.buffers_matrix(): + result.append(buf + "_stride") + for scalar in self.scalars: + if scalar != "": + default = "1.0" if scalar == "alpha" else "0.0" + result.append(scalar + " = " + default) + else: # batched but not strided-batched + for scalar in self.scalars: + result.append(scalar + "s") + for buf in buffers: + if buf in self.buffers_matrix(): + result.append(buf + "_ld") + for buf in buffers: + if buf in self.buffers_vector() + self.buffers_matrix(): + result.append(buf + "_offsets") + for buf in buffers: + if buf in self.buffers_vector(): + result.append(buf + "_inc = 1") for option in self.options: if option == "a_transpose": result.append("a_transp = False") @@ -849,8 +869,9 @@ class Routine: result.append("lower_triangle = False") if option == "diagonal": result.append("unit_diagonal = False") - for buf in buffers: - result.append(buf + "_offset = 0") + if self.batched != 1: + for buf in buffers: + result.append(buf + "_offset = 0") return result def requirements_doc(self): diff --git a/src/pyclblast/samples/saxpybatched.py b/src/pyclblast/samples/saxpybatched.py new file mode 100644 index 00000000..fa523945 --- /dev/null +++ b/src/pyclblast/samples/saxpybatched.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python + +# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. +# This file follows the PEP8 Python style guide and uses a max-width of 100 characters per line. 
+# +# Author(s): +# Cedric Nugteren <www.cedricnugteren.nl> + +import numpy as np +import pyopencl as cl +from pyopencl.array import Array +import pyclblast + +# Settings for this sample: +batch_count = 2 +dtype = 'float32' +alphas = [1.5, 1.0] +n = 4 + +print("# Setting up OpenCL") +ctx = cl.create_some_context() +queue = cl.CommandQueue(ctx) + +print("# Setting up Numpy arrays") +x = np.random.rand(n * batch_count).astype(dtype=dtype) +y = np.random.rand(n * batch_count).astype(dtype=dtype) + +print("# Batch offsets: next after each other") +x_offsets = [0, n] +y_offsets = [0, n] + +print("# Setting up OpenCL arrays") +clx = Array(queue, x.shape, x.dtype) +cly = Array(queue, y.shape, y.dtype) +clx.set(x) +cly.set(y) + +print("# Example level-1 batched operation: AXPY-batched") +assert len(alphas) == len(x_offsets) == len(y_offsets) == batch_count +pyclblast.axpyBatched(queue, n, clx, cly, alphas, x_offsets, y_offsets) +queue.finish() + +print("# Full result for vector y: %s" % str(cly.get())) +for i in range(batch_count): + result = alphas[i] * x[x_offsets[i]:x_offsets[i] + n] + y[y_offsets[i]:y_offsets[i] + n] + print("# Expected result batch #%d: %s" % (i, str(result))) diff --git a/src/pyclblast/setup.py b/src/pyclblast/setup.py index 1c1bf3ab..bcc966ed 100644 --- a/src/pyclblast/setup.py +++ b/src/pyclblast/setup.py @@ -22,7 +22,7 @@ ext_modules.append( setup( name="pyclblast", - version="1.2.0", + version="1.3.0", author="Cedric Nugteren", author_email="web@cedricnugteren.nl", url="https://github.com/CNugteren/CLBlast/blob/master/src/pyclblast", diff --git a/src/pyclblast/src/pyclblast.pyx b/src/pyclblast/src/pyclblast.pyx index 14efcf8a..eb46649f 100644 --- a/src/pyclblast/src/pyclblast.pyx +++ b/src/pyclblast/src/pyclblast.pyx @@ -364,6 +364,7 @@ def swap(queue, n, x, y, x_inc = 1, y_inc = 1, x_offset = 0, y_offset = 0): err = CLBlastHswap(n, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: 
Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXswap' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -405,6 +406,7 @@ def scal(queue, n, x, x_inc = 1, alpha = 1.0, x_offset = 0): err = CLBlastHscal(n, <cl_half>alpha, x_buffer, x_offset, x_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXscal' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -448,6 +450,7 @@ def copy(queue, n, x, y, x_inc = 1, y_inc = 1, x_offset = 0, y_offset = 0): err = CLBlastHcopy(n, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXcopy' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -491,6 +494,7 @@ def axpy(queue, n, x, y, x_inc = 1, y_inc = 1, alpha = 1.0, x_offset = 0, y_offs err = CLBlastHaxpy(n, <cl_half>alpha, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXaxpy' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -530,6 +534,7 @@ def dot(queue, n, x, y, dot, x_inc = 1, y_inc = 1, x_offset = 0, y_offset = 0, d err = CLBlastHdot(n, dot_buffer, dot_offset, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXdot' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -566,6 +571,7 @@ def dotu(queue, n, x, y, dot, x_inc = 1, y_inc = 1, 
x_offset = 0, y_offset = 0, err = CLBlastZdotu(n, dot_buffer, dot_offset, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXdotu' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -602,6 +608,7 @@ def dotc(queue, n, x, y, dot, x_inc = 1, y_inc = 1, x_offset = 0, y_offset = 0, err = CLBlastZdotc(n, dot_buffer, dot_offset, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXdotc' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -645,6 +652,7 @@ def nrm2(queue, n, x, nrm2, x_inc = 1, x_offset = 0, nrm2_offset = 0): err = CLBlastHnrm2(n, nrm2_buffer, nrm2_offset, x_buffer, x_offset, x_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXnrm2' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -688,6 +696,7 @@ def asum(queue, n, x, asum, x_inc = 1, x_offset = 0, asum_offset = 0): err = CLBlastHasum(n, asum_buffer, asum_offset, x_buffer, x_offset, x_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXasum' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -731,6 +740,7 @@ def sum(queue, n, x, sum, x_inc = 1, x_offset = 0, sum_offset = 0): err = CLBlastHsum(n, sum_buffer, sum_offset, x_buffer, x_offset, x_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 
'CLBlastXsum' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -774,6 +784,7 @@ def amax(queue, n, x, imax, x_inc = 1, x_offset = 0, imax_offset = 0): err = CLBlastiHamax(n, imax_buffer, imax_offset, x_buffer, x_offset, x_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXamax' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -817,6 +828,7 @@ def amin(queue, n, x, imin, x_inc = 1, x_offset = 0, imin_offset = 0): err = CLBlastiHamin(n, imin_buffer, imin_offset, x_buffer, x_offset, x_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXamin' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -860,6 +872,7 @@ def max(queue, n, x, imax, x_inc = 1, x_offset = 0, imax_offset = 0): err = CLBlastiHmax(n, imax_buffer, imax_offset, x_buffer, x_offset, x_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXmax' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -903,6 +916,7 @@ def min(queue, n, x, imin, x_inc = 1, x_offset = 0, imin_offset = 0): err = CLBlastiHmin(n, imin_buffer, imin_offset, x_buffer, x_offset, x_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXmin' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -949,6 +963,7 @@ def gemv(queue, m, n, a, x, y, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0 err = CLBlastHgemv(CLBlastLayoutRowMajor, a_transpose, m, n, <cl_half>alpha, a_buffer, a_offset, a_ld, x_buffer, x_offset, 
x_inc, <cl_half>beta, y_buffer, y_offset, y_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXgemv' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -995,6 +1010,7 @@ def gbmv(queue, m, n, kl, ku, a, x, y, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, err = CLBlastHgbmv(CLBlastLayoutRowMajor, a_transpose, m, n, kl, ku, <cl_half>alpha, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_half>beta, y_buffer, y_offset, y_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXgbmv' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1032,6 +1048,7 @@ def hemv(queue, n, a, x, y, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0.0, err = CLBlastZhemv(CLBlastLayoutRowMajor, triangle, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_double2>cl_double2(x=beta.real,y=beta.imag), y_buffer, y_offset, y_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXhemv' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1069,6 +1086,7 @@ def hbmv(queue, n, k, a, x, y, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0 err = CLBlastZhbmv(CLBlastLayoutRowMajor, triangle, n, k, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_double2>cl_double2(x=beta.real,y=beta.imag), y_buffer, y_offset, y_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXhbmv' failed: %s" % get_status_message(err)) return 
cl.Event.from_int_ptr(<size_t>event) @@ -1106,6 +1124,7 @@ def hpmv(queue, n, ap, x, y, ap_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0. err = CLBlastZhpmv(CLBlastLayoutRowMajor, triangle, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), ap_buffer, ap_offset, x_buffer, x_offset, x_inc, <cl_double2>cl_double2(x=beta.real,y=beta.imag), y_buffer, y_offset, y_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXhpmv' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1146,6 +1165,7 @@ def symv(queue, n, a, x, y, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0.0, err = CLBlastHsymv(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_half>beta, y_buffer, y_offset, y_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXsymv' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1186,6 +1206,7 @@ def sbmv(queue, n, k, a, x, y, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0 err = CLBlastHsbmv(CLBlastLayoutRowMajor, triangle, n, k, <cl_half>alpha, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_half>beta, y_buffer, y_offset, y_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXsbmv' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1226,6 +1247,7 @@ def spmv(queue, n, ap, x, y, ap_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0. 
err = CLBlastHspmv(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, ap_buffer, ap_offset, x_buffer, x_offset, x_inc, <cl_half>beta, y_buffer, y_offset, y_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXspmv' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1272,6 +1294,7 @@ def trmv(queue, n, a, x, a_ld, x_inc = 1, lower_triangle = False, a_transp = Fal err = CLBlastHtrmv(CLBlastLayoutRowMajor, triangle, a_transpose, diagonal, n, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXtrmv' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1318,6 +1341,7 @@ def tbmv(queue, n, k, a, x, a_ld, x_inc = 1, lower_triangle = False, a_transp = err = CLBlastHtbmv(CLBlastLayoutRowMajor, triangle, a_transpose, diagonal, n, k, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXtbmv' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1364,6 +1388,7 @@ def tpmv(queue, n, ap, x, ap_ld, x_inc = 1, lower_triangle = False, a_transp = F err = CLBlastHtpmv(CLBlastLayoutRowMajor, triangle, a_transpose, diagonal, n, ap_buffer, ap_offset, x_buffer, x_offset, x_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXtpmv' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1407,6 +1432,7 @@ def trsv(queue, n, a, x, a_ld, x_inc = 1, lower_triangle = False, a_transp = Fal err = 
CLBlastZtrsv(CLBlastLayoutRowMajor, triangle, a_transpose, diagonal, n, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXtrsv' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1446,6 +1472,7 @@ def ger(queue, m, n, x, y, a, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, x_offset err = CLBlastHger(CLBlastLayoutRowMajor, m, n, <cl_half>alpha, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, a_buffer, a_offset, a_ld, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXger' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1482,6 +1509,7 @@ def geru(queue, m, n, x, y, a, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, x_offset err = CLBlastZgeru(CLBlastLayoutRowMajor, m, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, a_buffer, a_offset, a_ld, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXgeru' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1518,6 +1546,7 @@ def gerc(queue, m, n, x, y, a, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, x_offset err = CLBlastZgerc(CLBlastLayoutRowMajor, m, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, a_buffer, a_offset, a_ld, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXgerc' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1553,6 +1582,7 @@ def her(queue, n, x, a, 
a_ld, x_inc = 1, alpha = 1.0, lower_triangle = False, x_ err = CLBlastZher(CLBlastLayoutRowMajor, triangle, n, <cl_double>alpha, x_buffer, x_offset, x_inc, a_buffer, a_offset, a_ld, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXher' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1588,6 +1618,7 @@ def hpr(queue, n, x, ap, ap_ld, x_inc = 1, alpha = 1.0, lower_triangle = False, err = CLBlastZhpr(CLBlastLayoutRowMajor, triangle, n, <cl_double>alpha, x_buffer, x_offset, x_inc, ap_buffer, ap_offset, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXhpr' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1625,6 +1656,7 @@ def her2(queue, n, x, y, a, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, lower_trian err = CLBlastZher2(CLBlastLayoutRowMajor, triangle, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, a_buffer, a_offset, a_ld, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXher2' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1662,6 +1694,7 @@ def hpr2(queue, n, x, y, ap, ap_ld, x_inc = 1, y_inc = 1, alpha = 1.0, lower_tri err = CLBlastZhpr2(CLBlastLayoutRowMajor, triangle, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, ap_buffer, ap_offset, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXhpr2' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ 
-1700,6 +1733,7 @@ def syr(queue, n, x, a, a_ld, x_inc = 1, alpha = 1.0, lower_triangle = False, x_ err = CLBlastHsyr(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, x_buffer, x_offset, x_inc, a_buffer, a_offset, a_ld, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXsyr' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1738,6 +1772,7 @@ def spr(queue, n, x, ap, ap_ld, x_inc = 1, alpha = 1.0, lower_triangle = False, err = CLBlastHspr(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, x_buffer, x_offset, x_inc, ap_buffer, ap_offset, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXspr' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1778,6 +1813,7 @@ def syr2(queue, n, x, y, a, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, lower_trian err = CLBlastHsyr2(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, a_buffer, a_offset, a_ld, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXsyr2' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1818,6 +1854,7 @@ def spr2(queue, n, x, y, ap, ap_ld, x_inc = 1, y_inc = 1, alpha = 1.0, lower_tri err = CLBlastHspr2(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, ap_buffer, ap_offset, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXspr2' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1865,6 +1902,7 @@ def gemm(queue, 
m, n, k, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, a_t err = CLBlastHgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_half>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_half>beta, c_buffer, c_offset, c_ld, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXgemm' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1912,6 +1950,7 @@ def symm(queue, m, n, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, right_ err = CLBlastHsymm(CLBlastLayoutRowMajor, side, triangle, m, n, <cl_half>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_half>beta, c_buffer, c_offset, c_ld, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXsymm' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1950,6 +1989,7 @@ def hemm(queue, m, n, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, right_ err = CLBlastZhemm(CLBlastLayoutRowMajor, side, triangle, m, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_double2>cl_double2(x=beta.real,y=beta.imag), c_buffer, c_offset, c_ld, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXhemm' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -1995,6 +2035,7 @@ def syrk(queue, n, k, a, c, a_ld, c_ld, alpha = 1.0, beta = 0.0, lower_triangle err = CLBlastHsyrk(CLBlastLayoutRowMajor, triangle, a_transpose, n, k, <cl_half>alpha, a_buffer, a_offset, a_ld, <cl_half>beta, c_buffer, c_offset, c_ld, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != 
CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXsyrk' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -2031,6 +2072,7 @@ def herk(queue, n, k, a, c, a_ld, c_ld, alpha = 1.0, beta = 0.0, lower_triangle err = CLBlastZherk(CLBlastLayoutRowMajor, triangle, a_transpose, n, k, <cl_double>alpha, a_buffer, a_offset, a_ld, <cl_double>beta, c_buffer, c_offset, c_ld, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXherk' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -2078,6 +2120,7 @@ def syr2k(queue, n, k, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, lower err = CLBlastHsyr2k(CLBlastLayoutRowMajor, triangle, ab_transpose, n, k, <cl_half>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_half>beta, c_buffer, c_offset, c_ld, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXsyr2k' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -2116,6 +2159,7 @@ def her2k(queue, n, k, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, lower err = CLBlastZher2k(CLBlastLayoutRowMajor, triangle, ab_transpose, n, k, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_double>beta, c_buffer, c_offset, c_ld, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXher2k' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -2163,6 +2207,7 @@ def trmm(queue, m, n, a, b, a_ld, b_ld, alpha = 1.0, right_side = False, lower_t err = CLBlastHtrmm(CLBlastLayoutRowMajor, side, triangle, a_transpose, diagonal, m, n, <cl_half>alpha, a_buffer, 
a_offset, a_ld, b_buffer, b_offset, b_ld, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXtrmm' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) @@ -2207,11 +2252,223 @@ def trsm(queue, m, n, a, b, a_ld, b_ld, alpha = 1.0, right_side = False, lower_t err = CLBlastZtrsm(CLBlastLayoutRowMajor, side, triangle, a_transpose, diagonal, m, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, &command_queue, &event) else: raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype) + if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'CLBlastXtrsm' failed: %s" % get_status_message(err)) return cl.Event.from_int_ptr(<size_t>event) #################################################################################################### +# Batched version of AXPY: SAXPYBATCHED/DAXPYBATCHED/CAXPYBATCHED/ZAXPYBATCHED/HAXPYBATCHED +#################################################################################################### + +cdef extern from "clblast_c.h": + CLBlastStatusCode CLBlastSaxpyBatched(const size_t n, const float *alphas, const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc, cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc, const size_t batch_count,cl_command_queue* queue, cl_event* event) + CLBlastStatusCode CLBlastDaxpyBatched(const size_t n, const double *alphas, const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc, cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc, const size_t batch_count,cl_command_queue* queue, cl_event* event) + CLBlastStatusCode CLBlastCaxpyBatched(const size_t n, const cl_float2 *alphas, const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc, cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc, const size_t batch_count,cl_command_queue* queue, cl_event* 
cdef extern from "clblast_c.h":
    CLBlastStatusCode CLBlastSaxpyBatched(const size_t n, const float *alphas, const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc, cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc, const size_t batch_count,cl_command_queue* queue, cl_event* event)
    CLBlastStatusCode CLBlastDaxpyBatched(const size_t n, const double *alphas, const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc, cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc, const size_t batch_count,cl_command_queue* queue, cl_event* event)
    CLBlastStatusCode CLBlastCaxpyBatched(const size_t n, const cl_float2 *alphas, const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc, cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc, const size_t batch_count,cl_command_queue* queue, cl_event* event)
    CLBlastStatusCode CLBlastZaxpyBatched(const size_t n, const cl_double2 *alphas, const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc, cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc, const size_t batch_count,cl_command_queue* queue, cl_event* event)
    CLBlastStatusCode CLBlastHaxpyBatched(const size_t n, const cl_half *alphas, const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc, cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc, const size_t batch_count,cl_command_queue* queue, cl_event* event)

def axpyBatched(queue, n, x, y, alphas, x_offsets, y_offsets, x_inc = 1, y_inc = 1):
    """
    xAXPYBATCHED: Batched version of AXPY

    The batch size is taken from the length of the batch-sized arguments
    (x_offsets, y_offsets and alphas), which must all have the same length.

    queue:      object exposing 'int_ptr', used as the OpenCL command queue
    n:          forwarded unchanged to CLBlast as 'n'
    x, y:       arrays accepted by check_vector; their 'base_data.int_ptr' is used as cl_mem
    alphas:     one scalar per batch (the '.real'/'.imag' parts are read for complex dtypes)
    x_offsets:  one offset into x per batch
    y_offsets:  one offset into y per batch
    x_inc:      forwarded unchanged to CLBlast
    y_inc:      forwarded unchanged to CLBlast

    Returns the CLBlast event wrapped through cl.Event.from_int_ptr.
    Raises RuntimeError when the batch-sized arguments differ in length or when CLBlast
    reports a failure, ValueError for an unrecognized dtype, and MemoryError when a
    temporary host array cannot be allocated.
    """

    dtype = check_dtype([x, y], ["float32", "float64", "complex64", "complex128", "float16"])
    check_vector(x, "x")
    check_vector(y, "y")

    # A chained 'a != b != c' means '(a != b) and (b != c)', which lets a single mismatching
    # list slip through; require all lengths to be equal instead.
    if not (len(x_offsets) == len(y_offsets) == len(alphas)):
        raise RuntimeError("PyCLBlast: 'CLBlastXaxpyBatched' failed: length of batch-sized arguments x_offsets, y_offsets, alphas should be equal")
    batch_count = len(x_offsets)

    cdef cl_mem x_buffer = <cl_mem><size_t>x.base_data.int_ptr
    cdef cl_mem y_buffer = <cl_mem><size_t>y.base_data.int_ptr

    cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr
    cdef cl_event event = NULL

    # Host-side copies of the batch-sized arguments. They start out as NULL so that the
    # 'finally' clause can free them unconditionally (PyMem_Free(NULL) is a no-op).
    cdef size_t *x_offsets_c = NULL
    cdef size_t *y_offsets_c = NULL
    cdef void *alphas_c = NULL

    cdef CLBlastStatusCode err
    try:
        x_offsets_c = <size_t *> PyMem_Malloc(batch_count * sizeof(size_t))
        y_offsets_c = <size_t *> PyMem_Malloc(batch_count * sizeof(size_t))
        if x_offsets_c == NULL or y_offsets_c == NULL:
            raise MemoryError()
        for i in range(batch_count):
            x_offsets_c[i] = x_offsets[i]
            y_offsets_c[i] = y_offsets[i]

        # The scalar array is sized with the concrete C element type of each branch.
        # NOTE(review): the previous 'sizeof(dtype_size[dtype])' depended on a lookup defined
        # outside this block; confirm it ever produced the per-element size for complex128.
        if dtype == np.dtype("float32"):
            alphas_c = <void *> PyMem_Malloc(batch_count * sizeof(cl_float))
            if alphas_c == NULL:
                raise MemoryError()
            for i in range(batch_count):
                (<cl_float*>alphas_c)[i] = <cl_float>alphas[i]
            err = CLBlastSaxpyBatched(n, <cl_float*>alphas_c, x_buffer, x_offsets_c, x_inc, y_buffer, y_offsets_c, y_inc, batch_count, &command_queue, &event)
        elif dtype == np.dtype("float64"):
            alphas_c = <void *> PyMem_Malloc(batch_count * sizeof(cl_double))
            if alphas_c == NULL:
                raise MemoryError()
            for i in range(batch_count):
                (<cl_double*>alphas_c)[i] = <cl_double>alphas[i]
            err = CLBlastDaxpyBatched(n, <cl_double*>alphas_c, x_buffer, x_offsets_c, x_inc, y_buffer, y_offsets_c, y_inc, batch_count, &command_queue, &event)
        elif dtype == np.dtype("complex64"):
            alphas_c = <void *> PyMem_Malloc(batch_count * sizeof(cl_float2))
            if alphas_c == NULL:
                raise MemoryError()
            for i in range(batch_count):
                (<cl_float2*>alphas_c)[i] = <cl_float2>cl_float2(x=alphas[i].real,y=alphas[i].imag)
            err = CLBlastCaxpyBatched(n, <cl_float2*>alphas_c, x_buffer, x_offsets_c, x_inc, y_buffer, y_offsets_c, y_inc, batch_count, &command_queue, &event)
        elif dtype == np.dtype("complex128"):
            alphas_c = <void *> PyMem_Malloc(batch_count * sizeof(cl_double2))
            if alphas_c == NULL:
                raise MemoryError()
            for i in range(batch_count):
                (<cl_double2*>alphas_c)[i] = <cl_double2>cl_double2(x=alphas[i].real,y=alphas[i].imag)
            err = CLBlastZaxpyBatched(n, <cl_double2*>alphas_c, x_buffer, x_offsets_c, x_inc, y_buffer, y_offsets_c, y_inc, batch_count, &command_queue, &event)
        elif dtype == np.dtype("float16"):
            alphas_c = <void *> PyMem_Malloc(batch_count * sizeof(cl_half))
            if alphas_c == NULL:
                raise MemoryError()
            for i in range(batch_count):
                (<cl_half*>alphas_c)[i] = <cl_half>alphas[i]
            err = CLBlastHaxpyBatched(n, <cl_half*>alphas_c, x_buffer, x_offsets_c, x_inc, y_buffer, y_offsets_c, y_inc, batch_count, &command_queue, &event)
        else:
            raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
    finally:
        # Runs on success and on every exception path, so a failing element conversion
        # no longer leaks the temporary arrays.
        PyMem_Free(x_offsets_c)
        PyMem_Free(y_offsets_c)
        PyMem_Free(alphas_c)

    if err != CLBlastSuccess:
        raise RuntimeError("PyCLBlast: 'CLBlastXaxpyBatched' failed: %s" % get_status_message(err))
    return cl.Event.from_int_ptr(<size_t>event)

####################################################################################################
# Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED
####################################################################################################

cdef extern from "clblast_c.h":
    CLBlastStatusCode CLBlastSgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const float *alphas, const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld, const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const float *betas, cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld, const size_t batch_count,cl_command_queue* queue, cl_event* event)
    CLBlastStatusCode CLBlastDgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const double *alphas, const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld, const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const double *betas, cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld, const size_t batch_count,cl_command_queue* queue, cl_event* event)
    CLBlastStatusCode CLBlastCgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const cl_float2 *alphas, const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld, const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const cl_float2 *betas, cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld, const size_t batch_count,cl_command_queue* queue, cl_event* event)
    CLBlastStatusCode CLBlastZgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const cl_double2 *alphas, const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld, const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const cl_double2 *betas, cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld, const size_t batch_count,cl_command_queue* queue, cl_event* event)
    CLBlastStatusCode CLBlastHgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const cl_half *alphas, const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld, const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const cl_half *betas, cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld, const size_t batch_count,cl_command_queue* queue, cl_event* event)

def gemmBatched(queue, m, n, k, a, b, c, alphas, betas, a_ld, b_ld, c_ld, a_offsets, b_offsets, c_offsets, a_transp = False, b_transp = False):
    """
    xGEMMBATCHED: Batched version of GEMM

    The batch size is taken from the length of the batch-sized arguments
    (a_offsets, b_offsets, c_offsets, alphas and betas), which must all have the same length.
    The CLBlast routine is always called with the row-major layout.

    queue:            object exposing 'int_ptr', used as the OpenCL command queue
    m, n, k:          forwarded unchanged to CLBlast
    a, b, c:          arrays accepted by check_matrix; their 'base_data.int_ptr' is used as cl_mem
    alphas, betas:    one scalar per batch (the '.real'/'.imag' parts are read for complex dtypes)
    a_ld, b_ld, c_ld: forwarded unchanged to CLBlast
    a_offsets, b_offsets, c_offsets: one offset per batch into a, b and c respectively
    a_transp:         truthy selects CLBlastTransposeYes for a, otherwise CLBlastTransposeNo
    b_transp:         truthy selects CLBlastTransposeYes for b, otherwise CLBlastTransposeNo

    Returns the CLBlast event wrapped through cl.Event.from_int_ptr.
    Raises RuntimeError when the batch-sized arguments differ in length or when CLBlast
    reports a failure, ValueError for an unrecognized dtype, and MemoryError when a
    temporary host array cannot be allocated.
    """

    dtype = check_dtype([a, b, c], ["float32", "float64", "complex64", "complex128", "float16"])
    check_matrix(a, "a")
    check_matrix(b, "b")
    check_matrix(c, "c")

    # A chained 'p != q != r ...' only compares neighbouring pairs and requires all of them to
    # differ, so a single mismatching list slips through; require all lengths to be equal.
    if not (len(a_offsets) == len(b_offsets) == len(c_offsets) == len(alphas) == len(betas)):
        raise RuntimeError("PyCLBlast: 'CLBlastXgemmBatched' failed: length of batch-sized arguments a_offsets, b_offsets, c_offsets, alphas, betas should be equal")
    batch_count = len(a_offsets)

    cdef cl_mem a_buffer = <cl_mem><size_t>a.base_data.int_ptr
    cdef cl_mem b_buffer = <cl_mem><size_t>b.base_data.int_ptr
    cdef cl_mem c_buffer = <cl_mem><size_t>c.base_data.int_ptr

    cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr
    cdef cl_event event = NULL
    a_transpose = CLBlastTransposeYes if a_transp else CLBlastTransposeNo
    b_transpose = CLBlastTransposeYes if b_transp else CLBlastTransposeNo

    # Host-side copies of the batch-sized arguments. They start out as NULL so that the
    # 'finally' clause can free them unconditionally (PyMem_Free(NULL) is a no-op).
    cdef size_t *a_offsets_c = NULL
    cdef size_t *b_offsets_c = NULL
    cdef size_t *c_offsets_c = NULL
    cdef void *alphas_c = NULL
    cdef void *betas_c = NULL

    cdef CLBlastStatusCode err
    try:
        a_offsets_c = <size_t *> PyMem_Malloc(batch_count * sizeof(size_t))
        b_offsets_c = <size_t *> PyMem_Malloc(batch_count * sizeof(size_t))
        c_offsets_c = <size_t *> PyMem_Malloc(batch_count * sizeof(size_t))
        if a_offsets_c == NULL or b_offsets_c == NULL or c_offsets_c == NULL:
            raise MemoryError()
        for i in range(batch_count):
            a_offsets_c[i] = a_offsets[i]
            b_offsets_c[i] = b_offsets[i]
            c_offsets_c[i] = c_offsets[i]

        # The scalar arrays are sized with the concrete C element type of each branch.
        # NOTE(review): the previous 'sizeof(dtype_size[dtype])' depended on a lookup defined
        # outside this block; confirm it ever produced the per-element size for complex128.
        if dtype == np.dtype("float32"):
            alphas_c = <void *> PyMem_Malloc(batch_count * sizeof(cl_float))
            betas_c = <void *> PyMem_Malloc(batch_count * sizeof(cl_float))
            if alphas_c == NULL or betas_c == NULL:
                raise MemoryError()
            for i in range(batch_count):
                (<cl_float*>alphas_c)[i] = <cl_float>alphas[i]
                (<cl_float*>betas_c)[i] = <cl_float>betas[i]
            err = CLBlastSgemmBatched(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_float*>alphas_c, a_buffer, a_offsets_c, a_ld, b_buffer, b_offsets_c, b_ld, <cl_float*>betas_c, c_buffer, c_offsets_c, c_ld, batch_count, &command_queue, &event)
        elif dtype == np.dtype("float64"):
            alphas_c = <void *> PyMem_Malloc(batch_count * sizeof(cl_double))
            betas_c = <void *> PyMem_Malloc(batch_count * sizeof(cl_double))
            if alphas_c == NULL or betas_c == NULL:
                raise MemoryError()
            for i in range(batch_count):
                (<cl_double*>alphas_c)[i] = <cl_double>alphas[i]
                (<cl_double*>betas_c)[i] = <cl_double>betas[i]
            err = CLBlastDgemmBatched(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_double*>alphas_c, a_buffer, a_offsets_c, a_ld, b_buffer, b_offsets_c, b_ld, <cl_double*>betas_c, c_buffer, c_offsets_c, c_ld, batch_count, &command_queue, &event)
        elif dtype == np.dtype("complex64"):
            alphas_c = <void *> PyMem_Malloc(batch_count * sizeof(cl_float2))
            betas_c = <void *> PyMem_Malloc(batch_count * sizeof(cl_float2))
            if alphas_c == NULL or betas_c == NULL:
                raise MemoryError()
            for i in range(batch_count):
                (<cl_float2*>alphas_c)[i] = <cl_float2>cl_float2(x=alphas[i].real,y=alphas[i].imag)
                (<cl_float2*>betas_c)[i] = <cl_float2>cl_float2(x=betas[i].real,y=betas[i].imag)
            err = CLBlastCgemmBatched(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_float2*>alphas_c, a_buffer, a_offsets_c, a_ld, b_buffer, b_offsets_c, b_ld, <cl_float2*>betas_c, c_buffer, c_offsets_c, c_ld, batch_count, &command_queue, &event)
        elif dtype == np.dtype("complex128"):
            alphas_c = <void *> PyMem_Malloc(batch_count * sizeof(cl_double2))
            betas_c = <void *> PyMem_Malloc(batch_count * sizeof(cl_double2))
            if alphas_c == NULL or betas_c == NULL:
                raise MemoryError()
            for i in range(batch_count):
                (<cl_double2*>alphas_c)[i] = <cl_double2>cl_double2(x=alphas[i].real,y=alphas[i].imag)
                (<cl_double2*>betas_c)[i] = <cl_double2>cl_double2(x=betas[i].real,y=betas[i].imag)
            err = CLBlastZgemmBatched(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_double2*>alphas_c, a_buffer, a_offsets_c, a_ld, b_buffer, b_offsets_c, b_ld, <cl_double2*>betas_c, c_buffer, c_offsets_c, c_ld, batch_count, &command_queue, &event)
        elif dtype == np.dtype("float16"):
            alphas_c = <void *> PyMem_Malloc(batch_count * sizeof(cl_half))
            betas_c = <void *> PyMem_Malloc(batch_count * sizeof(cl_half))
            if alphas_c == NULL or betas_c == NULL:
                raise MemoryError()
            for i in range(batch_count):
                (<cl_half*>alphas_c)[i] = <cl_half>alphas[i]
                (<cl_half*>betas_c)[i] = <cl_half>betas[i]
            err = CLBlastHgemmBatched(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_half*>alphas_c, a_buffer, a_offsets_c, a_ld, b_buffer, b_offsets_c, b_ld, <cl_half*>betas_c, c_buffer, c_offsets_c, c_ld, batch_count, &command_queue, &event)
        else:
            raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
    finally:
        # Runs on success and on every exception path, so a failing element conversion
        # no longer leaks the temporary arrays.
        PyMem_Free(a_offsets_c)
        PyMem_Free(b_offsets_c)
        PyMem_Free(c_offsets_c)
        PyMem_Free(alphas_c)
        PyMem_Free(betas_c)

    if err != CLBlastSuccess:
        raise RuntimeError("PyCLBlast: 'CLBlastXgemmBatched' failed: %s" % get_status_message(err))
    return cl.Event.from_int_ptr(<size_t>event)

####################################################################################################
# StridedBatched version of GEMM: SGEMMSTRIDEDBATCHED/DGEMMSTRIDEDBATCHED/CGEMMSTRIDEDBATCHED/ZGEMMSTRIDEDBATCHED/HGEMMSTRIDEDBATCHED
####################################################################################################

cdef extern from "clblast_c.h":
    CLBlastStatusCode CLBlastSgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const float alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const float beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, const size_t batch_count,cl_command_queue* queue, cl_event* event)
    CLBlastStatusCode CLBlastDgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const double alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const double beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, const size_t batch_count,cl_command_queue* queue, cl_event* event)
    CLBlastStatusCode CLBlastCgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const cl_float2 alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const cl_float2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, const size_t batch_count,cl_command_queue* queue, cl_event* event)
    CLBlastStatusCode CLBlastZgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const cl_double2 alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, const size_t batch_count,cl_command_queue* queue, cl_event* event)
    CLBlastStatusCode CLBlastHgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const cl_half alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const cl_half beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, const size_t batch_count,cl_command_queue* queue, cl_event* event)

def gemmStridedBatched(queue, m, n, k, batch_count, a, b, c, a_ld, b_ld, c_ld, a_stride, b_stride, c_stride, alpha = 1.0, beta = 0.0, a_transp = False, b_transp = False, a_offset = 0, b_offset = 0, c_offset = 0):
    """
    xGEMMSTRIDEDBATCHED: StridedBatched version of GEMM

    Picks the CLBlast<T>gemmStridedBatched routine matching the common dtype of a, b and c and
    calls it with the row-major layout. The sizes, leading dimensions, strides, offsets and
    batch_count are forwarded unchanged; alpha and beta are cast to the scalar type of the
    selected routine (the '.real'/'.imag' parts are read for complex dtypes).

    Returns the CLBlast event wrapped through cl.Event.from_int_ptr.
    Raises ValueError for an unrecognized dtype and RuntimeError when CLBlast reports a failure.
    """

    dtype = check_dtype([a, b, c], ["float32", "float64", "complex64", "complex128", "float16"])
    check_matrix(a, "a")
    check_matrix(b, "b")
    check_matrix(c, "c")

    # Raw OpenCL handles behind the Python-level objects
    cdef cl_mem a_mem = <cl_mem><size_t>a.base_data.int_ptr
    cdef cl_mem b_mem = <cl_mem><size_t>b.base_data.int_ptr
    cdef cl_mem c_mem = <cl_mem><size_t>c.base_data.int_ptr
    cdef cl_command_queue cmd_queue = <cl_command_queue><size_t>queue.int_ptr
    cdef cl_event done_event = NULL

    # Map the boolean flags onto the CLBlast transpose options
    if a_transp:
        trans_a = CLBlastTransposeYes
    else:
        trans_a = CLBlastTransposeNo
    if b_transp:
        trans_b = CLBlastTransposeYes
    else:
        trans_b = CLBlastTransposeNo

    cdef CLBlastStatusCode status
    if dtype == np.dtype("float32"):
        status = CLBlastSgemmStridedBatched(CLBlastLayoutRowMajor, trans_a, trans_b, m, n, k, <cl_float>alpha, a_mem, a_offset, a_ld, a_stride, b_mem, b_offset, b_ld, b_stride, <cl_float>beta, c_mem, c_offset, c_ld, c_stride, batch_count, &cmd_queue, &done_event)
    elif dtype == np.dtype("float64"):
        status = CLBlastDgemmStridedBatched(CLBlastLayoutRowMajor, trans_a, trans_b, m, n, k, <cl_double>alpha, a_mem, a_offset, a_ld, a_stride, b_mem, b_offset, b_ld, b_stride, <cl_double>beta, c_mem, c_offset, c_ld, c_stride, batch_count, &cmd_queue, &done_event)
    elif dtype == np.dtype("complex64"):
        status = CLBlastCgemmStridedBatched(CLBlastLayoutRowMajor, trans_a, trans_b, m, n, k, <cl_float2>cl_float2(x=alpha.real,y=alpha.imag), a_mem, a_offset, a_ld, a_stride, b_mem, b_offset, b_ld, b_stride, <cl_float2>cl_float2(x=beta.real,y=beta.imag), c_mem, c_offset, c_ld, c_stride, batch_count, &cmd_queue, &done_event)
    elif dtype == np.dtype("complex128"):
        status = CLBlastZgemmStridedBatched(CLBlastLayoutRowMajor, trans_a, trans_b, m, n, k, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_mem, a_offset, a_ld, a_stride, b_mem, b_offset, b_ld, b_stride, <cl_double2>cl_double2(x=beta.real,y=beta.imag), c_mem, c_offset, c_ld, c_stride, batch_count, &cmd_queue, &done_event)
    elif dtype == np.dtype("float16"):
        status = CLBlastHgemmStridedBatched(CLBlastLayoutRowMajor, trans_a, trans_b, m, n, k, <cl_half>alpha, a_mem, a_offset, a_ld, a_stride, b_mem, b_offset, b_ld, b_stride, <cl_half>beta, c_mem, c_offset, c_ld, c_stride, batch_count, &cmd_queue, &done_event)
    else:
        raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

    if status == CLBlastSuccess:
        return cl.Event.from_int_ptr(<size_t>done_event)
    raise RuntimeError("PyCLBlast: 'CLBlastXgemmStridedBatched' failed: %s" % get_status_message(status))

####################################################################################################
# Overrides the parameters
####################################################################################################