summaryrefslogtreecommitdiff
path: root/scripts/generator/routine.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/generator/routine.py')
-rw-r--r--scripts/generator/routine.py75
1 files changed, 74 insertions, 1 deletions
diff --git a/scripts/generator/routine.py b/scripts/generator/routine.py
index 95681da6..e5059c61 100644
--- a/scripts/generator/routine.py
+++ b/scripts/generator/routine.py
@@ -51,12 +51,24 @@ def OptionToWrapperC(x):
'diagonal': "CBLAS_DIAG",
}[x]
+# Translates an option name to a documentation string
+def OptionToDoc(x):
+ return {
+ 'layout': "Data-layout of the matrices, either `Layout::kRowMajor` (101) for row-major layout or `Layout::kColMajor` (102) for column-major data-layout.",
+ 'a_transpose': "Transposing the input matrix A, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.",
+ 'b_transpose': "Transposing the input matrix B, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.",
+ 'ab_transpose': "Transposing the packed input matrix AP, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.",
+ 'side': "The horizontal position of the triangular matrix, either `Side::kLeft` (141) or `Side::kRight` (142).",
+ 'triangle': "The vertical position of the triangular matrix, either `Triangle::kUpper` (121) or `Triangle::kLower` (122).",
+ 'diagonal': "The property of the diagonal matrix, either `Diagonal::kNonUnit` (131) for a non-unit values on the diagonal or `Diagonal::kUnit` (132) for a unit values on the diagonal.",
+ }[x]
+
# ==================================================================================================
# Class holding routine-specific information (e.g. name, which arguments, which precisions)
class Routine():
def __init__(self, implemented, has_tests, level, name, template, flavours, sizes, options,
- inputs, outputs, scalars, scratch, description):
+ inputs, outputs, scalars, scratch, description, details, requirements):
self.implemented = implemented
self.has_tests = has_tests
self.level = level
@@ -70,6 +82,8 @@ class Routine():
self.scalars = scalars
self.scratch = scratch # Scratch buffer (e.g. for xDOT)
self.description = description
+ self.details = details
+ self.requirements = requirements
# List of scalar buffers
def ScalarBuffersFirst(self):
@@ -115,6 +129,12 @@ class Routine():
return ["ap","a","b","c"]
return ["y","c"]
+ # Distinguish between vectors and matrices
+ def BuffersVector(self):
+ return ["x","y"]
+ def BuffersMatrix(self):
+ return ["a","b","c","ap"]
+
# ==============================================================================================
# Retrieves a variable name for a specific input/output vector/matrix (e.g. 'x')
@@ -197,6 +217,19 @@ class Routine():
return [", ".join(a+b+c)]
return []
+ # Retrieves the documentation of the buffers
+ def BufferDoc(self, name):
+ prefix = "const " if (name in self.inputs) else ""
+ inout = "input" if (name in self.inputs) else "output"
+ if (name in self.inputs) or (name in self.outputs):
+ math_name = name.upper()+" matrix" if (name in self.BuffersMatrix()) else name+" vector"
+ incld_description = "Leading dimension " if (name in self.BuffersMatrix()) else "Stride/increment "
+ a = ["`"+prefix+"cl_mem "+name+"_buffer`: OpenCL buffer to store the "+inout+" "+math_name+"."]
+ b = ["`const size_t "+name+"_offset`: The offset in elements from the start of the "+inout+" "+math_name+"."]
+ c = ["`const size_t "+name+"_"+self.Postfix(name)+"`: "+incld_description+"of the "+inout+" "+math_name+"."] if (name not in self.BuffersWithoutLdInc()) else []
+ return a+b+c
+ return []
+
# ==============================================================================================
# Retrieves the name of a scalar (alpha/beta)
@@ -257,6 +290,14 @@ class Routine():
return ["const "+flavour.beta_cpp]
return []
+ # Retrieves the documentation of a scalar
+ def ScalarDoc(self, name):
+ if name in self.scalars:
+ if name == "alpha":
+ return ["`const "+self.template.alpha_cpp+" "+name+"`: Input scalar constant."]
+ return ["`const "+self.template.beta_cpp+" "+name+"`: Input scalar constant."]
+ return []
+
# ==============================================================================================
# Retrieves a list of comma-separated sizes (m, n, k)
@@ -277,6 +318,13 @@ class Routine():
return [", ".join(["const size_t" for s in self.sizes])]
return []
+ # Retrieves the documentation of the sizes
+ def SizesDoc(self):
+ if self.sizes:
+ definitions = ["`const size_t "+s+"`: Integer size argument." for s in self.sizes]
+ return definitions
+ return []
+
# ==============================================================================================
# Retrieves a list of options
@@ -320,6 +368,13 @@ class Routine():
return [", ".join(definitions)]
return []
+ # Retrieves the documentation of the options
+ def OptionsDoc(self):
+ if self.options:
+ definitions = ["`const "+OptionToCLBlast(o)+"`: "+OptionToDoc(o) for o in self.options]
+ return definitions
+ return []
+
# ==============================================================================================
# Retrieves a combination of all the argument names, with Claduc casts
@@ -408,6 +463,24 @@ class Routine():
list(chain(*[self.BufferType(b) for b in self.BuffersSecond()])) +
list(chain(*[self.BufferType(b) for b in self.ScalarBuffersSecond()])) +
list(chain(*[self.ScalarType(s, flavour) for s in self.OtherScalars()])))
+
+ # Retrieves a combination of all the argument types
+ def ArgumentsDoc(self):
+ return (self.OptionsDoc() + self.SizesDoc() +
+ list(chain(*[self.BufferDoc(b) for b in self.ScalarBuffersFirst()])) +
+ list(chain(*[self.BufferDoc(b) for b in self.ScalarBuffersFirst()])) +
+ self.ScalarDoc("alpha") +
+ list(chain(*[self.BufferDoc(b) for b in self.BuffersFirst()])) +
+ self.ScalarDoc("beta") +
+ list(chain(*[self.BufferDoc(b) for b in self.BuffersSecond()])) +
+ list(chain(*[self.BufferDoc(b) for b in self.ScalarBuffersSecond()])) +
+ list(chain(*[self.ScalarDoc(s) for s in self.OtherScalars()])))
+
+ # ==============================================================================================
+
+ # Retrieves a list of routine requirements for documentation
+ def RequirementsDoc(self):
+ return []
# ==============================================================================================