diff options
Diffstat (limited to 'scripts/generator/generator/routine.py')
-rw-r--r-- | scripts/generator/generator/routine.py | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py index a4e682c2..2fa5e9d6 100644 --- a/scripts/generator/generator/routine.py +++ b/scripts/generator/generator/routine.py @@ -349,6 +349,13 @@ class Routine: return [", ".join(definitions)] return [] + def options_def_c(self): + """As above, but now for the C API""" + if self.options: + definitions = ["const CLBlast" + convert.option_to_clblast(o) + " " + o for o in self.options] + return [", ".join(definitions)] + return [] + def options_def_wrapper_clblas(self): """As above, but now using clBLAS data-types""" if self.options: @@ -453,6 +460,17 @@ class Routine: list(chain(*[self.buffer_def(b) for b in self.scalar_buffers_second()])) + list(chain(*[self.scalar_def(s, flavour) for s in self.other_scalars()]))) + def arguments_def_c(self, flavour): + """As above, but for the C API""" + return (self.options_def_c() + self.sizes_def() + + list(chain(*[self.buffer_def(b) for b in self.scalar_buffers_first()])) + + self.scalar_def("alpha", flavour) + + list(chain(*[self.buffer_def(b) for b in self.buffers_first()])) + + self.scalar_def("beta", flavour) + + list(chain(*[self.buffer_def(b) for b in self.buffers_second()])) + + list(chain(*[self.buffer_def(b) for b in self.scalar_buffers_second()])) + + list(chain(*[self.scalar_def(s, flavour) for s in self.other_scalars()]))) + def arguments_def_wrapper_clblas(self, flavour): """As above, but clBLAS wrapper plain data-types""" return (self.options_def_wrapper_clblas() + self.sizes_def() + @@ -523,8 +541,8 @@ class Routine: def routine_header_c(self, flavour, spaces, extra_qualifier): """As above, but now for C""" indent = " " * (spaces + self.length()) - result = "StatusCode" + extra_qualifier + " CLBlast" + flavour.name + self.name + "(" - result += (",\n" + indent).join([a for a in self.arguments_def(flavour)]) + result = "CLBlastStatusCode" + extra_qualifier + " CLBlast" + flavour.name + self.name + "(" + result += (",\n" + indent).join([a for a in self.arguments_def_c(flavour)]) result += ",\n" + indent + "cl_command_queue* queue, cl_event* event)" return result |