diff options
Diffstat (limited to 'scripts/generator/routine.py')
-rw-r--r-- | scripts/generator/routine.py | 49 |
1 files changed, 30 insertions, 19 deletions
diff --git a/scripts/generator/routine.py b/scripts/generator/routine.py index 1086cecc..d74def25 100644 --- a/scripts/generator/routine.py +++ b/scripts/generator/routine.py @@ -39,9 +39,6 @@ def OptionToWrapper(x): 'diagonal': "clblasDiag", }[x] -# Buffers without 'ld' or 'inc' parameter -NO_LD_INC = ["dot","ap"] - # ================================================================================================== # Class holding routine-specific information (e.g. name, which arguments, which precisions) @@ -61,6 +58,14 @@ class Routine(): self.scratch = scratch # Scratch buffer (e.g. for xDOT) self.description = description + # List of scalar buffers + def ScalarBuffers(self): + return ["SA","SB","C","S","dot"] + + # List of buffers without 'inc' or 'ld' + def BuffersWithoutLdInc(self): + return self.ScalarBuffers() + ["ap"] + # Retrieves the number of characters in the routine's name def Length(self): return len(self.name) @@ -94,7 +99,7 @@ class Routine(): if (name in self.inputs) or (name in self.outputs): a = [name+"_buffer"] b = [name+"_offset"] - c = [name+"_"+self.Postfix(name)] if (name not in NO_LD_INC) else [] + c = [name+"_"+self.Postfix(name)] if (name not in self.BuffersWithoutLdInc()) else [] return [", ".join(a+b+c)] return [] @@ -104,7 +109,7 @@ class Routine(): if (name in self.inputs) or (name in self.outputs): a = [prefix+"cl_mem "+name+"_buffer"] b = ["const size_t "+name+"_offset"] - c = ["const size_t "+name+"_"+self.Postfix(name)] if (name not in NO_LD_INC) else [] + c = ["const size_t "+name+"_"+self.Postfix(name)] if (name not in self.BuffersWithoutLdInc()) else [] return [", ".join(a+b+c)] return [] @@ -113,7 +118,7 @@ class Routine(): if (name in self.inputs) or (name in self.outputs): a = ["Buffer<"+self.template.buffertype+">("+name+"_buffer)"] b = [name+"_offset"] - c = [name+"_"+self.Postfix(name)] if (name not in NO_LD_INC) else [] + c = [name+"_"+self.Postfix(name)] if (name not in self.BuffersWithoutLdInc()) else [] return [", ".join(a+b+c)] return [] @@ -136,7 +141,7 @@ class Routine(): if (name in self.inputs) or (name in self.outputs): a = [prefix+"cl_mem"] b = ["const size_t"] - c = ["const size_t"] if (name not in NO_LD_INC) else [] + c = ["const size_t"] if (name not in self.BuffersWithoutLdInc()) else [] return [", ".join(a+b+c)] return [] @@ -252,57 +257,63 @@ class Routine(): # Retrieves a combination of all the argument names, with Claduc casts def ArgumentsCladuc(self, flavour, indent): - return (self.Options() + self.Sizes() + self.BufferCladuc("dot") + + return (self.Options() + self.Sizes() + + list(chain(*[self.BufferCladuc(b) for b in self.ScalarBuffers()])) + self.Scalar("alpha") + list(chain(*[self.BufferCladuc(b) for b in self.BuffersFirst()])) + self.Scalar("beta") + list(chain(*[self.BufferCladuc(b) for b in self.BuffersSecond()])) + - list(chain(*[self.Scalar(s) for s in ["d1","d2","a","b","c","s"]]))) + list(chain(*[self.Scalar(s) for s in ["C","S"]]))) # Retrieves a combination of all the argument names, with CLBlast casts def ArgumentsCast(self, flavour, indent): - return (self.OptionsCast(indent) + self.Sizes() + self.Buffer("dot") + + return (self.OptionsCast(indent) + self.Sizes() + + list(chain(*[self.Buffer(b) for b in self.ScalarBuffers()])) + self.ScalarUse("alpha", flavour) + list(chain(*[self.Buffer(b) for b in self.BuffersFirst()])) + self.ScalarUse("beta", flavour) + list(chain(*[self.Buffer(b) for b in self.BuffersSecond()])) + - list(chain(*[self.ScalarUse(s, flavour) for s in ["d1","d2","a","b","c","s"]]))) + list(chain(*[self.ScalarUse(s, flavour) for s in ["C","S"]]))) # As above, but for the clBLAS wrapper def ArgumentsWrapper(self, flavour): - return (self.Options() + self.Sizes() + self.BufferWrapper("dot") + + return (self.Options() + self.Sizes() + + list(chain(*[self.BufferWrapper(b) for b in self.ScalarBuffers()])) + self.ScalarUseWrapper("alpha", flavour) + list(chain(*[self.BufferWrapper(b) for b in self.BuffersFirst()])) + self.ScalarUseWrapper("beta", flavour) + list(chain(*[self.BufferWrapper(b) for b in self.BuffersSecond()])) + - list(chain(*[self.ScalarUseWrapper(s, flavour) for s in ["d1","d2","a","b","c","s"]]))) + list(chain(*[self.ScalarUseWrapper(s, flavour) for s in ["C","S"]]))) # Retrieves a combination of all the argument definitions def ArgumentsDef(self, flavour): - return (self.OptionsDef() + self.SizesDef() + self.BufferDef("dot") + + return (self.OptionsDef() + self.SizesDef() + + list(chain(*[self.BufferDef(b) for b in self.ScalarBuffers()])) + self.ScalarDef("alpha", flavour) + list(chain(*[self.BufferDef(b) for b in self.BuffersFirst()])) + self.ScalarDef("beta", flavour) + list(chain(*[self.BufferDef(b) for b in self.BuffersSecond()])) + - list(chain(*[self.ScalarDef(s, flavour) for s in ["d1","d2","a","b","c","s"]]))) + list(chain(*[self.ScalarDef(s, flavour) for s in ["C","S"]]))) # As above, but clBLAS wrapper plain datatypes def ArgumentsDefWrapper(self, flavour): - return (self.OptionsDefWrapper() + self.SizesDef() + self.BufferDef("dot") + + return (self.OptionsDefWrapper() + self.SizesDef() + + list(chain(*[self.BufferDef(b) for b in self.ScalarBuffers()])) + self.ScalarDefPlain("alpha", flavour) + list(chain(*[self.BufferDef(b) for b in self.BuffersFirst()])) + self.ScalarDefPlain("beta", flavour) + list(chain(*[self.BufferDef(b) for b in self.BuffersSecond()])) + - list(chain(*[self.ScalarDefPlain(s, flavour) for s in ["d1","d2","a","b","c","s"]]))) + list(chain(*[self.ScalarDefPlain(s, flavour) for s in ["C","S"]]))) # Retrieves a combination of all the argument types def ArgumentsType(self, flavour): - return (self.OptionsType() + self.SizesType() + self.BufferType("dot") + + return (self.OptionsType() + self.SizesType() + + list(chain(*[self.BufferType(b) for b in self.ScalarBuffers()])) + self.ScalarType("alpha", flavour) + list(chain(*[self.BufferType(b) for b in self.BuffersFirst()])) + self.ScalarType("beta", flavour) + list(chain(*[self.BufferType(b) for b in self.BuffersSecond()])) + - list(chain(*[self.ScalarType(s, flavour) for s in ["d1","d2","a","b","c","s"]]))) + list(chain(*[self.ScalarType(s, flavour) for s in ["C","S"]]))) # ============================================================================================== |