summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkodonell <kodonell@users.noreply.github.com>2018-03-09 15:27:33 +1300
committerkodonell <kodonell@users.noreply.github.com>2018-03-09 15:27:33 +1300
commit54a4b871b3f5fcbcfaa8f6bc7c56c8664527dd04 (patch)
treea6b92fd251a63a97f9b9f274788ef3e9802e6a36
parent269bddbf34e5cad00f3845d1a68974420997a040 (diff)
initial add of override parameters to pyclblast - cython not complaining, but segfault
-rw-r--r--src/pyclblast/src/pyclblast.pyx36
1 files changed, 35 insertions, 1 deletions
diff --git a/src/pyclblast/src/pyclblast.pyx b/src/pyclblast/src/pyclblast.pyx
index 7bab3531..7c2d6736 100644
--- a/src/pyclblast/src/pyclblast.pyx
+++ b/src/pyclblast/src/pyclblast.pyx
@@ -1,3 +1,4 @@
+#distutils: language = c++
#cython: binding=True
####################################################################################################
# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0.
@@ -13,7 +14,6 @@
import numpy as np
import pyopencl as cl
from pyopencl.array import Array
-
from libcpp cimport bool
####################################################################################################
@@ -2083,3 +2083,37 @@ def trsm(queue, m, n, a, b, a_ld, b_ld, alpha = 1.0, right_side = False, lower_t
return cl.Event.from_int_ptr(<size_t>event)
####################################################################################################
+# Overrides the parameters
+####################################################################################################
+
+from libc.stdlib cimport malloc, free
+from libc.string cimport strdup
+
+cdef extern from "clblast_c.h":
+ ctypedef void* cl_device_id
+ CLBlastStatusCode CLBlastOverrideParameters(const cl_device_id device, const char* kernel_name, const CLBlastPrecision precision, const size_t num_parameters, const char** parameters_names, const size_t* parameters_values)
+
+def override_parameters(ctx, kernel_name, precision, parameters):
+ """
+ precision = 16, 32, 64, 3232, 6464
+ kernel name = unicode string
+ parameters =
+ """
+
+ cdef size_t n = len(parameters)
+ cdef const char **parameter_names = <const char**> malloc(n * sizeof(char*))
+ cdef size_t *parameter_values = <size_t*> malloc(n * sizeof(size_t))
+
+ # TODO: check mallocs
+
+ for i, (k, v) in enumerate(parameters.items()):
+ parameter_names[i] = strdup(k.encode('ascii'))
+ parameter_values[i] = v
+
+ err = CLBlastOverrideParameters((<cl_device_id> ctx.devices[0].ptr), kernel_name.encode('ascii'), precision, n, parameter_names, parameter_values)
+ if err != CLBlastSuccess:
+ raise RuntimeError("PyCLBlast: 'OverrideParameters' failed: %s" % get_status_message(err))
+
+ # TODO: free etc.
+
+####################################################################################################