summary refs log tree commit diff
path: root/scripts/generator/generator/routine.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/generator/generator/routine.py')
-rw-r--r-- scripts/generator/generator/routine.py 110
1 files changed, 108 insertions, 2 deletions
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index 59b2ed73..1c534611 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -197,6 +197,10 @@ class Routine:
"""Determines whether or not this routine has scalar arguments (alpha/beta)"""
return self.scalars == []
+ def has_layout(self):
+ """Determines whether the layout is an argument"""
+ return "layout" in self.options
+
def short_names(self):
"""Returns the upper-case names of these routines (all flavours)"""
return "/".join([f.name + self.upper_name() for f in self.flavours])
@@ -257,7 +261,7 @@ class Routine:
return []
def buffer_def_wrapper_cl(self, name, flavour):
- """As above but with data-types"""
+ """As above but for OpenCL"""
prefix = "const " if name in self.inputs else ""
if name in self.inputs or name in self.outputs:
a = [prefix + "Buffer<" + flavour.buffer_type + ">& " + name + "_buffer"]
@@ -266,6 +270,16 @@ class Routine:
return [", ".join(a + b + c)]
return []
+ def buffer_def_wrapper_cuda(self, name, flavour):
+ """As above but for CUDA"""
+ prefix = "const " if name in self.inputs else ""
+ if name in self.inputs or name in self.outputs:
+ a = [prefix + flavour.buffer_type + "* " + name + "_buffer"]
+ b = ["const size_t " + name + "_offset"]
+ c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
+ return [", ".join(a + b + c)]
+ return []
+
def buffer_def_vector(self, name, flavour):
"""As above but as vectors"""
prefix = "const " if name in self.inputs else ""
@@ -329,6 +343,32 @@ class Routine:
return [", ".join(a + c)]
return []
+ def buffer_wrapper_cublas(self, name, flavour):
+ """As above but for the cuBLAS wrapper"""
+ prefix = "const " if name in self.inputs else ""
+ if name in self.inputs or name in self.outputs:
+ if name in self.index_buffers():
+ a = ["reinterpret_cast<int*>(&" + name + "_buffer[" + name + "_offset])"]
+ elif name in self.outputs and flavour.name in ["Sc", "Dz"]:
+ dtype = "float" if flavour.name == "Sc" else "double"
+ a = ["reinterpret_cast<" + dtype + "*>(&" + name + "_buffer[" + name + "_offset])"]
+ elif flavour.precision_name in ["C", "Z"]:
+ cuda_complex = "cuDoubleComplex" if flavour.precision_name == "Z" else "cuComplex"
+ a = ["reinterpret_cast<" + prefix + cuda_complex + "*>" +
+ "(&" + name + "_buffer[" + name + "_offset])"]
+ else:
+ a = ["&" + name + "_buffer[" + name + "_offset]"]
+ c = []
+ if name in ["x", "y"]:
+ c = ["static_cast<int>(" + name + "_" + self.postfix(name) + ")"]
+ elif name in ["a", "b", "c"]:
+ c = [name + "_" + self.postfix(name)]
+ result = [", ".join(a + c)]
+ if self.name == "trmm" and name == "a":
+ result *= 2
+ return result
+ return []
+
def buffer_type(self, name):
"""As above, but only data-types"""
prefix = "const " if (name in self.inputs) else ""
@@ -407,6 +447,14 @@ class Routine:
return [name]
return []
+ def scalar_use_wrapper_cublas(self, name, flavour):
+ """As above, but for the cuBLAS wrapper"""
+ if name in self.scalars:
+ if flavour.is_complex(name):
+ return ["&" + name + "_cuda"]
+ return ["&" + name]
+ return []
+
def scalar_def(self, name, flavour):
"""Retrieves the definition of a scalar (alpha/beta)"""
if name in self.scalars:
@@ -465,6 +513,12 @@ class Routine:
return [", ".join([s for s in self.sizes])]
return []
+ def sizes_list_as_int(self):
+ """Retrieves a list of comma-separated sizes (m, n, k) cast to integers"""
+ if self.sizes:
+ return [", ".join(["static_cast<int>(" + s + ")" for s in self.sizes])]
+ return []
+
def sizes_def(self):
"""Retrieves the definition of the sizes (m,n,k)"""
if self.sizes:
@@ -496,6 +550,15 @@ class Routine:
return [", ".join(self.options)]
return []
+ def options_list_no_layout(self):
+ """Retrieves a list of options, excluding the layout option"""
+ options = self.options[:]
+ if "layout" in options:
+ options.remove("layout")
+ if options:
+ return [", ".join(options)]
+ return []
+
def options_cast(self, indent):
"""As above, but now casted to CLBlast data-types"""
if self.options:
@@ -531,6 +594,13 @@ class Routine:
return [", ".join(definitions)]
return []
+ def options_def_wrapper_cublas(self):
+ """As above, but now using cuBLAS data-types"""
+ if self.options:
+ definitions = ["const " + convert.option_to_cublas(o) + " " + o for o in self.options]
+ return [", ".join(definitions)]
+ return []
+
def options_type(self):
"""Retrieves the types of the options (layout, transpose, side, etc.)"""
if self.options:
@@ -615,7 +685,7 @@ class Routine:
def arguments_wrapper_cblas(self, flavour):
"""As above, but for the CBLAS wrapper"""
- return (self.options_list() + self.sizes_list() +
+ return (self.options_list() + self.sizes_list_as_int() +
self.scalar_use_wrapper_cblas("alpha", flavour) +
list(chain(*[self.buffer_wrapper_cblas(b, flavour) for b in self.buffers_first()])) +
self.scalar_use_wrapper_cblas("beta", flavour) +
@@ -623,6 +693,17 @@ class Routine:
list(chain(*[self.buffer_wrapper_cblas(b, flavour) for b in self.scalar_buffers_second()])) +
list(chain(*[self.scalar_use_wrapper_cblas(s, flavour) for s in self.other_scalars()])))
+ def arguments_wrapper_cublas(self, flavour):
+ """As above, but for the cuBLAS wrapper"""
+ return (self.options_list_no_layout() + self.sizes_list_as_int() +
+ self.scalar_use_wrapper_cublas("alpha", flavour) +
+ list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.buffers_first()])) +
+ self.scalar_use_wrapper_cublas("beta", flavour) +
+ list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.scalar_buffers_first()])) +
+ list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar_use_wrapper_cublas(s, flavour) for s in self.other_scalars()])))
+
def arguments_def(self, flavour):
"""Retrieves a combination of all the argument definitions"""
return (self.options_def() + self.sizes_def() +
@@ -683,6 +764,17 @@ class Routine:
list(chain(*[self.buffer_def_vector(b, flavour) for b in self.scalar_buffers_second()])) +
list(chain(*[self.scalar_def_plain(s, flavour) for s in self.other_scalars()])))
+ def arguments_def_wrapper_cublas(self, flavour):
+ """As above, but with plain data-types for the cuBLAS wrapper"""
+ return (self.options_def_wrapper_cublas() + self.sizes_def() +
+ list(chain(*[self.buffer_def_wrapper_cuda(b, flavour) for b in self.scalar_buffers_first()])) +
+ self.scalar_def_plain("alpha", flavour) +
+ list(chain(*[self.buffer_def_wrapper_cuda(b, flavour) for b in self.buffers_first()])) +
+ self.scalar_def_plain("beta", flavour) +
+ list(chain(*[self.buffer_def_wrapper_cuda(b, flavour) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_def_wrapper_cuda(b, flavour) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar_def_plain(s, flavour) for s in self.other_scalars()])))
+
def arguments_type(self, flavour):
"""Retrieves a combination of all the argument types"""
return (self.options_type() + self.sizes_type() +
@@ -781,3 +873,17 @@ class Routine:
result = "void cblasX" + self.name + "("
result += (",\n" + indent).join([a for a in self.arguments_def_wrapper_cblas(flavour)]) + ")"
return result
+
+ def routine_header_wrapper_cublas(self, flavour, def_only, spaces):
+ """As above, but now for the cuBLAS wrapper"""
+ template = "<" + flavour.template + ">" if self.no_scalars() and not def_only else ""
+ indent = " " * (spaces + self.length() + len(template))
+ result = ""
+ if self.no_scalars():
+ result += "template <"
+ if def_only:
+ result += flavour.name
+ result += ">\n"
+ result += "cublasStatus_t cublasX" + self.name + template + "(cublasHandle_t handle, "
+ result += (",\n" + indent).join([a for a in self.arguments_def_wrapper_cublas(flavour)]) + ")"
+ return result