diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/pyclblast/src/pyclblast.pyx | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/src/pyclblast/src/pyclblast.pyx b/src/pyclblast/src/pyclblast.pyx index 7c2d6736..9115240f 100644 --- a/src/pyclblast/src/pyclblast.pyx +++ b/src/pyclblast/src/pyclblast.pyx @@ -2090,16 +2090,20 @@ from libc.stdlib cimport malloc, free from libc.string cimport strdup cdef extern from "clblast_c.h": - ctypedef void* cl_device_id + ctypedef struct _cl_device_id: + pass + ctypedef _cl_device_id* cl_device_id + CLBlastStatusCode CLBlastOverrideParameters(const cl_device_id device, const char* kernel_name, const CLBlastPrecision precision, const size_t num_parameters, const char** parameters_names, const size_t* parameters_values) -def override_parameters(ctx, kernel_name, precision, parameters): +def override_parameters(device, kernel_name, precision, parameters): """ precision = 16, 32, 64, 3232, 6464 kernel name = unicode string parameters = """ - + + cdef cl_device_id device_id = <cl_device_id><size_t>device.int_ptr cdef size_t n = len(parameters) cdef const char **parameter_names = <const char**> malloc(n * sizeof(char*)) cdef size_t *parameter_values = <size_t*> malloc(n * sizeof(size_t)) @@ -2110,7 +2114,7 @@ def override_parameters(ctx, kernel_name, precision, parameters): parameter_names[i] = strdup(k.encode('ascii')) parameter_values[i] = v - err = CLBlastOverrideParameters((<cl_device_id> ctx.devices[0].ptr), kernel_name.encode('ascii'), precision, n, parameter_names, parameter_values) + err = CLBlastOverrideParameters(device_id, kernel_name.encode('ascii'), precision, n, parameter_names, parameter_values) if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'OverrideParameters' failed: %s" % get_status_message(err)) |