diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2016-11-22 08:41:52 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2016-11-22 08:41:52 +0100 |
commit | 26ca07148092b5d4fcb0e25190e07bf6acae25a3 (patch) | |
tree | 81854684aa03e09ad471228024f6c986b1b60f55 /scripts/generator/generator | |
parent | eefe0df43575686c6aa48a9fb6e25e27bef1af40 (diff) |
Minor changes to ensure full compatibility with the Netlib CBLAS API
Diffstat (limited to 'scripts/generator/generator')
-rw-r--r-- | scripts/generator/generator/cpp.py | 7 | ||||
-rw-r--r-- | scripts/generator/generator/routine.py | 33 |
2 files changed, 31 insertions, 9 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py index 7b7ece22..6bb3080f 100644 --- a/scripts/generator/generator/cpp.py +++ b/scripts/generator/generator/cpp.py @@ -112,6 +112,7 @@ def clblast_netlib_c_cc(routine): # There is a version available in CBLAS if flavour.precision_name in ["S", "D", "C", "Z"]: template = "<" + flavour.template + ">" if routine.no_scalars() else "" + name_postfix = "_sub" if routine.name in routine.routines_scalar_no_return() else "" indent = " " * (21 + routine.length() + len(template)) result += routine.routine_header_netlib(flavour, 9, "") + " {" + NL @@ -129,6 +130,8 @@ def clblast_netlib_c_cc(routine): for i, name in enumerate(routine.inputs + routine.outputs): buffer_type = routine.get_buffer_type(name, flavour) result += " " + routine.create_buffer(name, buffer_type) + NL + if name in routine.scalar_buffers_second_non_pointer(): + result += " " + buffer_type + " " + name + "_vec[1]; " + name + "_vec[0] = " + name + ";" + NL for name in routine.inputs + routine.outputs: if name not in routine.scalar_buffers_first(): prefix = "" if name in routine.outputs else "const " @@ -148,14 +151,14 @@ def clblast_netlib_c_cc(routine): # Copy back and clean-up for name in routine.outputs: - if name in routine.scalar_buffers_first(): + if name in routine.scalar_buffers_first() and routine.name not in routine.routines_scalar_no_return(): buffer_type = routine.get_buffer_type(name, flavour) result += " " + buffer_type + " " + name + "[" + name + "_size];" + NL for name in routine.outputs: buffer_type = routine.get_buffer_type(name, flavour) result += " " + routine.read_buffer(name, buffer_type) + NL for name in routine.outputs: - if name in routine.scalar_buffers_first(): + if name in routine.scalar_buffers_first() and routine.name not in routine.routines_scalar_no_return(): result += " return " + name + "[0]" if flavour.buffer_type in ["float2", "double2"]: if name not in routine.index_buffers(): diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py index 391cf3e0..6fcce23b 100644 --- a/scripts/generator/generator/routine.py +++ b/scripts/generator/generator/routine.py @@ -43,6 +43,11 @@ class Routine: return ["sa", "sb", "sc", "ss", "sd1", "sd2", "sx1", "sy1", "sparam"] @staticmethod + def scalar_buffers_second_non_pointer(): + """As above, but these ones are not passed as pointers but as scalars instead""" + return ["sy1"] + + @staticmethod def other_scalars(): """List of scalars other than alpha and beta""" return ["cos", "sin"] @@ -68,6 +73,10 @@ class Routine: return ["a", "b", "c", "ap"] @staticmethod + def routines_scalar_no_return(): + return ["dotu", "dotc"] + + @staticmethod def set_size(name, size): """Sets the size of a buffer""" return "const auto " + name + "_size = " + size + ";" @@ -77,10 +86,12 @@ class Routine: """Creates a new CLCudaAPI buffer""" return "auto " + name + "_buffer = clblast::Buffer<" + template + ">(context, " + name + "_size);" - @staticmethod - def write_buffer(name, template): + def write_buffer(self, name, template): """Writes to a CLCudaAPI buffer""" - data_structure = "reinterpret_cast<" + template + "*>(" + name + ")" + postfix = "" + if name in self.scalar_buffers_second_non_pointer(): + postfix = "_vec" + data_structure = "reinterpret_cast<" + template + "*>(" + name + postfix + ")" return name + "_buffer.Write(queue, " + name + "_size, " + data_structure + ");" @staticmethod @@ -206,7 +217,8 @@ class Routine: prefix = "const " if name in self.inputs else "" if name in self.inputs or name in self.outputs: data_type = "void" if flavour.is_non_standard() else flavour.buffer_type - a = [prefix + data_type + "* " + name + ""] + pointer = "" if name in self.scalar_buffers_second_non_pointer() else "*" + a = [prefix + data_type + pointer + " " + name + ""] c = ["const int " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else [] return [", ".join(a + c)] return [] @@ -553,13 +565,16 @@ class Routine: def arguments_def_netlib(self, flavour): """As above, but for the Netlib CBLAS API""" - return (self.options_def_c() + self.sizes_def_netlib() + + result=(self.options_def_c() + self.sizes_def_netlib() + self.scalar_def_void("alpha", flavour) + list(chain(*[self.buffer_def_pointer(b, flavour) for b in self.buffers_first()])) + self.scalar_def_void("beta", flavour) + list(chain(*[self.buffer_def_pointer(b, flavour) for b in self.buffers_second()])) + list(chain(*[self.buffer_def_pointer(b, flavour) for b in self.scalar_buffers_second()])) + list(chain(*[self.scalar_def(s, flavour) for s in self.other_scalars()]))) + if self.name in self.routines_scalar_no_return(): + result += list(chain(*[self.buffer_def_pointer(b, flavour) for b in self.scalar_buffers_first()])) + return result def arguments_def_c(self, flavour): """As above, but for the C API""" @@ -654,11 +669,15 @@ class Routine: if output in self.index_buffers(): return_type = "int" break - if output in self.scalar_buffers_first(): + if output in self.scalar_buffers_first() and self.name not in self.routines_scalar_no_return(): return_type = flavour.buffer_type.replace("2", "") break indent = " " * (spaces + len(return_type) + self.length()) - result = return_type + extra_qualifier + " cblas_" + flavour.name.lower() + self.name + "(" + routine_name = self.name + if self.name in self.routines_scalar_no_return(): + routine_name += "_sub" + indent += " " + result = return_type + extra_qualifier + " cblas_" + flavour.name.lower() + routine_name + "(" result += (",\n" + indent).join([a for a in self.arguments_def_netlib(flavour)]) + ")" return result |