diff options
Diffstat (limited to 'scripts/generator/generator/routine.py')
-rw-r--r-- | scripts/generator/generator/routine.py | 110 |
1 files changed, 108 insertions, 2 deletions
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py index 59b2ed73..1c534611 100644 --- a/scripts/generator/generator/routine.py +++ b/scripts/generator/generator/routine.py @@ -197,6 +197,10 @@ class Routine: """Determines whether or not this routine has scalar arguments (alpha/beta)""" return self.scalars == [] + def has_layout(self): + """Determines whether the layout is an argument""" + return "layout" in self.options + def short_names(self): """Returns the upper-case names of these routines (all flavours)""" return "/".join([f.name + self.upper_name() for f in self.flavours]) @@ -257,7 +261,7 @@ class Routine: return [] def buffer_def_wrapper_cl(self, name, flavour): - """As above but with data-types""" + """As above but for OpenCL""" prefix = "const " if name in self.inputs else "" if name in self.inputs or name in self.outputs: a = [prefix + "Buffer<" + flavour.buffer_type + ">& " + name + "_buffer"] @@ -266,6 +270,16 @@ class Routine: return [", ".join(a + b + c)] return [] + def buffer_def_wrapper_cuda(self, name, flavour): + """As above but for CUDA""" + prefix = "const " if name in self.inputs else "" + if name in self.inputs or name in self.outputs: + a = [prefix + flavour.buffer_type + "* " + name + "_buffer"] + b = ["const size_t " + name + "_offset"] + c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else [] + return [", ".join(a + b + c)] + return [] + def buffer_def_vector(self, name, flavour): """As above but as vectors""" prefix = "const " if name in self.inputs else "" @@ -329,6 +343,32 @@ class Routine: return [", ".join(a + c)] return [] + def buffer_wrapper_cublas(self, name, flavour): + """As above but for the cuBLAS wrapper""" + prefix = "const " if name in self.inputs else "" + if name in self.inputs or name in self.outputs: + if name in self.index_buffers(): + a = ["reinterpret_cast<int*>(&" + name + "_buffer[" + name + "_offset])"] + elif name in 
self.outputs and flavour.name in ["Sc", "Dz"]: + dtype = "float" if flavour.name == "Sc" else "double" + a = ["reinterpret_cast<" + dtype + "*>(&" + name + "_buffer[" + name + "_offset])"] + elif flavour.precision_name in ["C", "Z"]: + cuda_complex = "cuDoubleComplex" if flavour.precision_name == "Z" else "cuComplex" + a = ["reinterpret_cast<" + prefix + cuda_complex + "*>" + + "(&" + name + "_buffer[" + name + "_offset])"] + else: + a = ["&" + name + "_buffer[" + name + "_offset]"] + c = [] + if name in ["x", "y"]: + c = ["static_cast<int>(" + name + "_" + self.postfix(name) + ")"] + elif name in ["a", "b", "c"]: + c = [name + "_" + self.postfix(name)] + result = [", ".join(a + c)] + if self.name == "trmm" and name == "a": + result *= 2 + return result + return [] + def buffer_type(self, name): """As above, but only data-types""" prefix = "const " if (name in self.inputs) else "" @@ -407,6 +447,14 @@ class Routine: return [name] return [] + def scalar_use_wrapper_cublas(self, name, flavour): + """As above, but for the cuBLAS wrapper""" + if name in self.scalars: + if flavour.is_complex(name): + return ["&" + name + "_cuda"] + return ["&" + name] + return [] + def scalar_def(self, name, flavour): """Retrieves the definition of a scalar (alpha/beta)""" if name in self.scalars: @@ -465,6 +513,12 @@ class Routine: return [", ".join([s for s in self.sizes])] return [] + def sizes_list_as_int(self): + """Retrieves a list of comma-separated sizes (m, n, k) cast to integers""" + if self.sizes: + return [", ".join(["static_cast<int>(" + s + ")" for s in self.sizes])] + return [] + def sizes_def(self): """Retrieves the definition of the sizes (m,n,k)""" if self.sizes: @@ -496,6 +550,15 @@ class Routine: return [", ".join(self.options)] return [] + def options_list_no_layout(self): + """Retrieves a list of options, excluding the layout""" + options = self.options[:] + if "layout" in options: + options.remove("layout") + if options: + return [", ".join(options)] + return [] + def 
options_cast(self, indent): """As above, but now casted to CLBlast data-types""" if self.options: @@ -531,6 +594,13 @@ class Routine: return [", ".join(definitions)] return [] + def options_def_wrapper_cublas(self): + """As above, but now using cuBLAS data-types""" + if self.options: + definitions = ["const " + convert.option_to_cublas(o) + " " + o for o in self.options] + return [", ".join(definitions)] + return [] + def options_type(self): """Retrieves the types of the options (layout, transpose, side, etc.)""" if self.options: @@ -615,7 +685,7 @@ class Routine: def arguments_wrapper_cblas(self, flavour): """As above, but for the CBLAS wrapper""" - return (self.options_list() + self.sizes_list() + + return (self.options_list() + self.sizes_list_as_int() + self.scalar_use_wrapper_cblas("alpha", flavour) + list(chain(*[self.buffer_wrapper_cblas(b, flavour) for b in self.buffers_first()])) + self.scalar_use_wrapper_cblas("beta", flavour) + @@ -623,6 +693,17 @@ class Routine: list(chain(*[self.buffer_wrapper_cblas(b, flavour) for b in self.scalar_buffers_second()])) + list(chain(*[self.scalar_use_wrapper_cblas(s, flavour) for s in self.other_scalars()]))) + def arguments_wrapper_cublas(self, flavour): + """As above, but for the cuBLAS wrapper""" + return (self.options_list_no_layout() + self.sizes_list_as_int() + + self.scalar_use_wrapper_cublas("alpha", flavour) + + list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.buffers_first()])) + + self.scalar_use_wrapper_cublas("beta", flavour) + + list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.buffers_second()])) + + list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.scalar_buffers_first()])) + + list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.scalar_buffers_second()])) + + list(chain(*[self.scalar_use_wrapper_cublas(s, flavour) for s in self.other_scalars()]))) + def arguments_def(self, flavour): """Retrieves a combination of all the argument definitions""" 
return (self.options_def() + self.sizes_def() + @@ -683,6 +764,17 @@ class Routine: list(chain(*[self.buffer_def_vector(b, flavour) for b in self.scalar_buffers_second()])) + list(chain(*[self.scalar_def_plain(s, flavour) for s in self.other_scalars()]))) + def arguments_def_wrapper_cublas(self, flavour): + """As above, but for the cuBLAS wrapper with plain data-types""" + return (self.options_def_wrapper_cublas() + self.sizes_def() + + list(chain(*[self.buffer_def_wrapper_cuda(b, flavour) for b in self.scalar_buffers_first()])) + + self.scalar_def_plain("alpha", flavour) + + list(chain(*[self.buffer_def_wrapper_cuda(b, flavour) for b in self.buffers_first()])) + + self.scalar_def_plain("beta", flavour) + + list(chain(*[self.buffer_def_wrapper_cuda(b, flavour) for b in self.buffers_second()])) + + list(chain(*[self.buffer_def_wrapper_cuda(b, flavour) for b in self.scalar_buffers_second()])) + + list(chain(*[self.scalar_def_plain(s, flavour) for s in self.other_scalars()]))) + def arguments_type(self, flavour): """Retrieves a combination of all the argument types""" return (self.options_type() + self.sizes_type() + @@ -781,3 +873,17 @@ class Routine: result = "void cblasX" + self.name + "(" result += (",\n" + indent).join([a for a in self.arguments_def_wrapper_cblas(flavour)]) + ")" return result + + def routine_header_wrapper_cublas(self, flavour, def_only, spaces): + """As above, but now for the cuBLAS wrapper""" + template = "<" + flavour.template + ">" if self.no_scalars() and not def_only else "" + indent = " " * (spaces + self.length() + len(template)) + result = "" + if self.no_scalars(): + result += "template <" + if def_only: + result += flavour.name + result += ">\n" + result += "cublasStatus_t cublasX" + self.name + template + "(cublasHandle_t handle, " + result += (",\n" + indent).join([a for a in self.arguments_def_wrapper_cublas(flavour)]) + ")" + return result |