diff options
Diffstat (limited to 'scripts/generator/routine.py')
-rw-r--r-- | scripts/generator/routine.py | 75 |
1 files changed, 74 insertions, 1 deletions
diff --git a/scripts/generator/routine.py b/scripts/generator/routine.py index 95681da6..e5059c61 100644 --- a/scripts/generator/routine.py +++ b/scripts/generator/routine.py @@ -51,12 +51,24 @@ def OptionToWrapperC(x): 'diagonal': "CBLAS_DIAG", }[x] +# Translates an option name to a documentation string +def OptionToDoc(x): + return { + 'layout': "Data-layout of the matrices, either `Layout::kRowMajor` (101) for row-major layout or `Layout::kColMajor` (102) for column-major data-layout.", + 'a_transpose': "Transposing the input matrix A, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.", + 'b_transpose': "Transposing the input matrix B, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.", + 'ab_transpose': "Transposing the packed input matrix AP, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.", + 'side': "The horizontal position of the triangular matrix, either `Side::kLeft` (141) or `Side::kRight` (142).", + 'triangle': "The vertical position of the triangular matrix, either `Triangle::kUpper` (121) or `Triangle::kLower` (122).", + 'diagonal': "The property of the diagonal matrix, either `Diagonal::kNonUnit` (131) for a non-unit values on the diagonal or `Diagonal::kUnit` (132) for a unit values on the diagonal.", + }[x] + # ================================================================================================== # Class holding routine-specific information (e.g. name, which arguments, which precisions) class Routine(): def __init__(self, implemented, has_tests, level, name, template, flavours, sizes, options, - inputs, outputs, scalars, scratch, description): + inputs, outputs, scalars, scratch, description, details, requirements): self.implemented = implemented self.has_tests = has_tests self.level = level @@ -70,6 +82,8 @@ class Routine(): self.scalars = scalars self.scratch = scratch # Scratch buffer (e.g. for xDOT) self.description = description + self.details = details + self.requirements = requirements # List of scalar buffers def ScalarBuffersFirst(self): @@ -115,6 +129,12 @@ class Routine(): return ["ap","a","b","c"] return ["y","c"] + # Distinguish between vectors and matrices + def BuffersVector(self): + return ["x","y"] + def BuffersMatrix(self): + return ["a","b","c","ap"] + # ============================================================================================== # Retrieves a variable name for a specific input/output vector/matrix (e.g. 'x') @@ -197,6 +217,19 @@ class Routine(): return [", ".join(a+b+c)] return [] + # Retrieves the documentation of the buffers + def BufferDoc(self, name): + prefix = "const " if (name in self.inputs) else "" + inout = "input" if (name in self.inputs) else "output" + if (name in self.inputs) or (name in self.outputs): + math_name = name.upper()+" matrix" if (name in self.BuffersMatrix()) else name+" vector" + incld_description = "Leading dimension " if (name in self.BuffersMatrix()) else "Stride/increment " + a = ["`"+prefix+"cl_mem "+name+"_buffer`: OpenCL buffer to store the "+inout+" "+math_name+"."] + b = ["`const size_t "+name+"_offset`: The offset in elements from the start of the "+inout+" "+math_name+"."] + c = ["`const size_t "+name+"_"+self.Postfix(name)+"`: "+incld_description+"of the "+inout+" "+math_name+"."] if (name not in self.BuffersWithoutLdInc()) else [] + return a+b+c + return [] + # ============================================================================================== # Retrieves the name of a scalar (alpha/beta) @@ -257,6 +290,14 @@ class Routine(): return ["const "+flavour.beta_cpp] return [] + # Retrieves the documentation of a scalar + def ScalarDoc(self, name): + if name in self.scalars: + if name == "alpha": + return ["`const "+self.template.alpha_cpp+" "+name+"`: Input scalar constant."] + return ["`const "+self.template.beta_cpp+" "+name+"`: Input scalar constant."] + return [] + # ============================================================================================== # Retrieves a list of comma-separated sizes (m, n, k) @@ -277,6 +318,13 @@ class Routine(): return [", ".join(["const size_t" for s in self.sizes])] return [] + # Retrieves the documentation of the sizes + def SizesDoc(self): + if self.sizes: + definitions = ["`const size_t "+s+"`: Integer size argument." for s in self.sizes] + return definitions + return [] + # ============================================================================================== # Retrieves a list of options @@ -320,6 +368,13 @@ class Routine(): return [", ".join(definitions)] return [] + # Retrieves the documentation of the options + def OptionsDoc(self): + if self.options: + definitions = ["`const "+OptionToCLBlast(o)+"`: "+OptionToDoc(o) for o in self.options] + return definitions + return [] + # ============================================================================================== # Retrieves a combination of all the argument names, with Claduc casts @@ -408,6 +463,24 @@ class Routine(): list(chain(*[self.BufferType(b) for b in self.BuffersSecond()])) + list(chain(*[self.BufferType(b) for b in self.ScalarBuffersSecond()])) + list(chain(*[self.ScalarType(s, flavour) for s in self.OtherScalars()]))) + + # Retrieves a combination of all the argument types + def ArgumentsDoc(self): + return (self.OptionsDoc() + self.SizesDoc() + + list(chain(*[self.BufferDoc(b) for b in self.ScalarBuffersFirst()])) + + list(chain(*[self.BufferDoc(b) for b in self.ScalarBuffersFirst()])) + + self.ScalarDoc("alpha") + + list(chain(*[self.BufferDoc(b) for b in self.BuffersFirst()])) + + self.ScalarDoc("beta") + + list(chain(*[self.BufferDoc(b) for b in self.BuffersSecond()])) + + list(chain(*[self.BufferDoc(b) for b in self.ScalarBuffersSecond()])) + + list(chain(*[self.ScalarDoc(s) for s in self.OtherScalars()]))) + + # ============================================================================================== + + # Retrieves a list of routine requirements for documentation + def RequirementsDoc(self): + return [] # ============================================================================================== |