diff options
author | kodonell <kodonell@users.noreply.github.com> | 2018-03-10 22:21:30 +1300 |
---|---|---|
committer | kodonell <kodonell@users.noreply.github.com> | 2018-03-10 22:21:30 +1300 |
commit | c6056da0c85aa69ebb550c39509af011248027b4 (patch) | |
tree | f2220f7130b3afe780fc849618da429fa965d8a3 | |
parent | 54a4b871b3f5fcbcfaa8f6bc7c56c8664527dd04 (diff) |
ok, device id working
-rw-r--r-- | src/pyclblast/src/pyclblast.pyx | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/src/pyclblast/src/pyclblast.pyx b/src/pyclblast/src/pyclblast.pyx index 7c2d6736..9115240f 100644 --- a/src/pyclblast/src/pyclblast.pyx +++ b/src/pyclblast/src/pyclblast.pyx @@ -2090,16 +2090,20 @@ from libc.stdlib cimport malloc, free from libc.string cimport strdup cdef extern from "clblast_c.h": - ctypedef void* cl_device_id + ctypedef struct _cl_device_id: + pass + ctypedef _cl_device_id* cl_device_id + CLBlastStatusCode CLBlastOverrideParameters(const cl_device_id device, const char* kernel_name, const CLBlastPrecision precision, const size_t num_parameters, const char** parameters_names, const size_t* parameters_values) -def override_parameters(ctx, kernel_name, precision, parameters): +def override_parameters(device, kernel_name, precision, parameters): """ precision = 16, 32, 64, 3232, 6464 kernel name = unicode string parameters = """ - + + cdef cl_device_id device_id = <cl_device_id><size_t>device.int_ptr cdef size_t n = len(parameters) cdef const char **parameter_names = <const char**> malloc(n * sizeof(char*)) cdef size_t *parameter_values = <size_t*> malloc(n * sizeof(size_t)) @@ -2110,7 +2114,7 @@ def override_parameters(ctx, kernel_name, precision, parameters): parameter_names[i] = strdup(k.encode('ascii')) parameter_values[i] = v - err = CLBlastOverrideParameters((<cl_device_id> ctx.devices[0].ptr), kernel_name.encode('ascii'), precision, n, parameter_names, parameter_values) + err = CLBlastOverrideParameters(device_id, kernel_name.encode('ascii'), precision, n, parameter_names, parameter_values) if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'OverrideParameters' failed: %s" % get_status_message(err)) |