summaryrefslogtreecommitdiff
path: root/scripts/generator/generator/routine.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/generator/generator/routine.py')
-rw-r--r--scripts/generator/generator/routine.py22
1 files changed, 20 insertions, 2 deletions
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index a4e682c2..2fa5e9d6 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -349,6 +349,13 @@ class Routine:
return [", ".join(definitions)]
return []
+ def options_def_c(self):
+ """As above, but now for the C API"""
+ if self.options:
+ definitions = ["const CLBlast" + convert.option_to_clblast(o) + " " + o for o in self.options]
+ return [", ".join(definitions)]
+ return []
+
def options_def_wrapper_clblas(self):
"""As above, but now using clBLAS data-types"""
if self.options:
@@ -453,6 +460,17 @@ class Routine:
list(chain(*[self.buffer_def(b) for b in self.scalar_buffers_second()])) +
list(chain(*[self.scalar_def(s, flavour) for s in self.other_scalars()])))
+ def arguments_def_c(self, flavour):
+ """As above, but for the C API"""
+ return (self.options_def_c() + self.sizes_def() +
+ list(chain(*[self.buffer_def(b) for b in self.scalar_buffers_first()])) +
+ self.scalar_def("alpha", flavour) +
+ list(chain(*[self.buffer_def(b) for b in self.buffers_first()])) +
+ self.scalar_def("beta", flavour) +
+ list(chain(*[self.buffer_def(b) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_def(b) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar_def(s, flavour) for s in self.other_scalars()])))
+
def arguments_def_wrapper_clblas(self, flavour):
"""As above, but clBLAS wrapper plain data-types"""
return (self.options_def_wrapper_clblas() + self.sizes_def() +
@@ -523,8 +541,8 @@ class Routine:
def routine_header_c(self, flavour, spaces, extra_qualifier):
"""As above, but now for C"""
indent = " " * (spaces + self.length())
- result = "StatusCode" + extra_qualifier + " CLBlast" + flavour.name + self.name + "("
- result += (",\n" + indent).join([a for a in self.arguments_def(flavour)])
+ result = "CLBlastStatusCode" + extra_qualifier + " CLBlast" + flavour.name + self.name + "("
+ result += (",\n" + indent).join([a for a in self.arguments_def_c(flavour)])
result += ",\n" + indent + "cl_command_queue* queue, cl_event* event)"
return result