diff options
author | kodonell <kodonell@users.noreply.github.com> | 2018-03-09 15:27:33 +1300 |
---|---|---|
committer | kodonell <kodonell@users.noreply.github.com> | 2018-03-09 15:27:33 +1300 |
commit | 54a4b871b3f5fcbcfaa8f6bc7c56c8664527dd04 (patch) | |
tree | a6b92fd251a63a97f9b9f274788ef3e9802e6a36 | |
parent | 269bddbf34e5cad00f3845d1a68974420997a040 (diff) |
initial add of override parameters to pyclblast - cython not complaining, but segfault
-rw-r--r-- | src/pyclblast/src/pyclblast.pyx | 36 |
1 files changed, 35 insertions, 1 deletions
diff --git a/src/pyclblast/src/pyclblast.pyx b/src/pyclblast/src/pyclblast.pyx index 7bab3531..7c2d6736 100644 --- a/src/pyclblast/src/pyclblast.pyx +++ b/src/pyclblast/src/pyclblast.pyx @@ -1,3 +1,4 @@ +#distutils: language = c++ #cython: binding=True #################################################################################################### # This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. @@ -13,7 +14,6 @@ import numpy as np import pyopencl as cl from pyopencl.array import Array - from libcpp cimport bool #################################################################################################### @@ -2083,3 +2083,37 @@ def trsm(queue, m, n, a, b, a_ld, b_ld, alpha = 1.0, right_side = False, lower_t return cl.Event.from_int_ptr(<size_t>event) #################################################################################################### +# Overrides the parameters +#################################################################################################### + +from libc.stdlib cimport malloc, free +from libc.string cimport strdup + +cdef extern from "clblast_c.h": + ctypedef void* cl_device_id + CLBlastStatusCode CLBlastOverrideParameters(const cl_device_id device, const char* kernel_name, const CLBlastPrecision precision, const size_t num_parameters, const char** parameters_names, const size_t* parameters_values) + +def override_parameters(ctx, kernel_name, precision, parameters): + """ + precision = 16, 32, 64, 3232, 6464 + kernel name = unicode string + parameters = + """ + + cdef size_t n = len(parameters) + cdef const char **parameter_names = <const char**> malloc(n * sizeof(char*)) + cdef size_t *parameter_values = <size_t*> malloc(n * sizeof(size_t)) + + # TODO: check mallocs + + for i, (k, v) in enumerate(parameters.items()): + parameter_names[i] = strdup(k.encode('ascii')) + parameter_values[i] = v + + err = CLBlastOverrideParameters((<cl_device_id> ctx.devices[0].ptr), kernel_name.encode('ascii'), precision, n, parameter_names, parameter_values) + if err != CLBlastSuccess: + raise RuntimeError("PyCLBlast: 'OverrideParameters' failed: %s" % get_status_message(err)) + + # TODO: free etc. + +#################################################################################################### |