From c6056da0c85aa69ebb550c39509af011248027b4 Mon Sep 17 00:00:00 2001 From: kodonell Date: Sat, 10 Mar 2018 22:21:30 +1300 Subject: ok, device id working --- src/pyclblast/src/pyclblast.pyx | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'src/pyclblast') diff --git a/src/pyclblast/src/pyclblast.pyx b/src/pyclblast/src/pyclblast.pyx index 7c2d6736..9115240f 100644 --- a/src/pyclblast/src/pyclblast.pyx +++ b/src/pyclblast/src/pyclblast.pyx @@ -2090,16 +2090,20 @@ from libc.stdlib cimport malloc, free from libc.string cimport strdup cdef extern from "clblast_c.h": - ctypedef void* cl_device_id + ctypedef struct _cl_device_id: + pass + ctypedef _cl_device_id* cl_device_id + CLBlastStatusCode CLBlastOverrideParameters(const cl_device_id device, const char* kernel_name, const CLBlastPrecision precision, const size_t num_parameters, const char** parameters_names, const size_t* parameters_values) -def override_parameters(ctx, kernel_name, precision, parameters): +def override_parameters(device, kernel_name, precision, parameters): """ precision = 16, 32, 64, 3232, 6464 kernel name = unicode string parameters = """ - + + cdef cl_device_id device_id = device.int_ptr cdef size_t n = len(parameters) cdef const char **parameter_names = malloc(n * sizeof(char*)) cdef size_t *parameter_values = malloc(n * sizeof(size_t)) @@ -2110,7 +2114,7 @@ def override_parameters(ctx, kernel_name, precision, parameters): parameter_names[i] = strdup(k.encode('ascii')) parameter_values[i] = v - err = CLBlastOverrideParameters(( ctx.devices[0].ptr), kernel_name.encode('ascii'), precision, n, parameter_names, parameter_values) + err = CLBlastOverrideParameters(device_id, kernel_name.encode('ascii'), precision, n, parameter_names, parameter_values) if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'OverrideParameters' failed: %s" % get_status_message(err)) -- cgit v1.2.3