summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkodonell <kodonell@users.noreply.github.com>2018-03-10 22:21:30 +1300
committerkodonell <kodonell@users.noreply.github.com>2018-03-10 22:21:30 +1300
commitc6056da0c85aa69ebb550c39509af011248027b4 (patch)
treef2220f7130b3afe780fc849618da429fa965d8a3
parent54a4b871b3f5fcbcfaa8f6bc7c56c8664527dd04 (diff)
ok, device id working
-rw-r--r--src/pyclblast/src/pyclblast.pyx12
1 files changed, 8 insertions, 4 deletions
diff --git a/src/pyclblast/src/pyclblast.pyx b/src/pyclblast/src/pyclblast.pyx
index 7c2d6736..9115240f 100644
--- a/src/pyclblast/src/pyclblast.pyx
+++ b/src/pyclblast/src/pyclblast.pyx
@@ -2090,16 +2090,20 @@ from libc.stdlib cimport malloc, free
from libc.string cimport strdup
cdef extern from "clblast_c.h":
- ctypedef void* cl_device_id
+ ctypedef struct _cl_device_id:
+ pass
+ ctypedef _cl_device_id* cl_device_id
+
CLBlastStatusCode CLBlastOverrideParameters(const cl_device_id device, const char* kernel_name, const CLBlastPrecision precision, const size_t num_parameters, const char** parameters_names, const size_t* parameters_values)
-def override_parameters(ctx, kernel_name, precision, parameters):
+def override_parameters(device, kernel_name, precision, parameters):
"""
precision = 16, 32, 64, 3232, 6464
kernel name = unicode string
parameters =
"""
-
+
+ cdef cl_device_id device_id = <cl_device_id><size_t>device.int_ptr
cdef size_t n = len(parameters)
cdef const char **parameter_names = <const char**> malloc(n * sizeof(char*))
cdef size_t *parameter_values = <size_t*> malloc(n * sizeof(size_t))
@@ -2110,7 +2114,7 @@ def override_parameters(ctx, kernel_name, precision, parameters):
parameter_names[i] = strdup(k.encode('ascii'))
parameter_values[i] = v
- err = CLBlastOverrideParameters((<cl_device_id> ctx.devices[0].ptr), kernel_name.encode('ascii'), precision, n, parameter_names, parameter_values)
+ err = CLBlastOverrideParameters(device_id, kernel_name.encode('ascii'), precision, n, parameter_names, parameter_values)
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'OverrideParameters' failed: %s" % get_status_message(err))