diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2016-05-12 19:56:21 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2016-05-12 19:56:21 +0200 |
commit | f2ba75890c522b4fe1762bfeac3e08667cf9588a (patch) | |
tree | 82e22cb72fbfb135570ce3bf3234bd1f60c760c1 /scripts | |
parent | 1c72d225c53c123ed810cf3f56f5c92603f7f791 (diff) |
Initial changes in preparation for half-precision fp16 support
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/generator/datatype.py | 3 | ||||
-rw-r--r-- | scripts/generator/generator.py | 114 | ||||
-rw-r--r-- | scripts/generator/routine.py | 6 |
3 files changed, 68 insertions, 55 deletions
diff --git a/scripts/generator/datatype.py b/scripts/generator/datatype.py index 5a58ab53..5bff95d1 100644 --- a/scripts/generator/datatype.py +++ b/scripts/generator/datatype.py @@ -13,10 +13,13 @@ # ================================================================================================== # Short-hands for data-types +HLF = "half" FLT = "float" DBL = "double" FLT2 = "float2" DBL2 = "double2" + +HCL = "cl_half" F2CL = "cl_float2" D2CL = "cl_double2" diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index 210f371f..bc8fa783 100644 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -28,11 +28,12 @@ import os.path # Local files from routine import Routine -from datatype import DataType, FLT, DBL, FLT2, DBL2, F2CL, D2CL +from datatype import DataType, HLF, FLT, DBL, FLT2, DBL2, HCL, F2CL, D2CL # ================================================================================================== # Regular data-types +H = DataType("H", "H", HLF, [HLF, HLF, HCL, HCL], HLF ) # half (16) S = DataType("S", "S", FLT, [FLT, FLT, FLT, FLT], FLT ) # single (32) D = DataType("D", "D", DBL, [DBL, DBL, DBL, DBL], DBL ) # double (64) C = DataType("C", "C", FLT2, [FLT2, FLT2, F2CL, F2CL], FLT2) # single-complex (3232) @@ -67,7 +68,7 @@ routines = [ Routine(True, True, "1", "swap", T, [S,D,C,Z], ["n"], [], [], ["x","y"], [], "", "Swap two vectors", "Interchanges the contents of vectors x and y.", []), Routine(True, True, "1", "scal", T, [S,D,C,Z], ["n"], [], [], ["x"], ["alpha"], "", "Vector scaling", "Multiplies all elements of vector x by a scalar constant alpha.", []), Routine(True, True, "1", "copy", T, [S,D,C,Z], ["n"], [], ["x"], ["y"], [], "", "Vector copy", "Copies the contents of vector x into vector y.", []), - Routine(True, True, "1", "axpy", T, [S,D,C,Z], ["n"], [], ["x"], ["y"], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation y = alpha * x + y, in which x and y are vectors and alpha is a scalar constant.", []), + Routine(True, True, "1", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation y = alpha * x + y, in which x and y are vectors and alpha is a scalar constant.", []), Routine(True, True, "1", "dot", T, [S,D], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two vectors", "Multiplies the vectors x and y element-wise and accumulates the results. The sum is stored in the dot buffer.", []), Routine(True, True, "1", "dotu", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two complex vectors", "See the regular xDOT routine.", []), Routine(True, True, "1", "dotc", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two complex vectors, one conjugated", "See the regular xDOT routine.", []), @@ -229,22 +230,23 @@ def wrapper_clblas(routines): result = "" for routine in routines: if routine.has_tests: - result += "\n// Forwards the clBLAS calls for %s\n" % (routine.ShortNames()) + result += "\n// Forwards the clBLAS calls for %s\n" % (routine.ShortNamesTested()) if routine.NoScalars(): result += routine.RoutineHeaderWrapperCL(routine.template, True, 21)+";\n" for flavour in routine.flavours: - indent = " "*(17 + routine.Length()) - result += routine.RoutineHeaderWrapperCL(flavour, False, 21)+" {\n" - arguments = routine.ArgumentsWrapperCL(flavour) - if routine.scratch: - result += " auto queue = Queue(queues[0]);\n" - result += " auto context = queue.GetContext();\n" - result += " auto scratch_buffer = Buffer<"+flavour.template+">(context, "+routine.scratch+");\n" - arguments += ["scratch_buffer()"] - result += " return clblas"+flavour.name+routine.name+"(" - result += (",\n"+indent).join([a for a in arguments]) - result += ",\n"+indent+"num_queues, queues, num_wait_events, wait_events, events);" - result += "\n}\n" + if flavour.precision_name in ["S","D","C","Z"]: + indent = " "*(17 + routine.Length()) + result += routine.RoutineHeaderWrapperCL(flavour, False, 21)+" {\n" + arguments = routine.ArgumentsWrapperCL(flavour) + if routine.scratch: + result += " auto queue = Queue(queues[0]);\n" + result += " auto context = queue.GetContext();\n" + result += " auto scratch_buffer = Buffer<"+flavour.template+">(context, "+routine.scratch+");\n" + arguments += ["scratch_buffer()"] + result += " return clblas"+flavour.name+routine.name+"(" + result += (",\n"+indent).join([a for a in arguments]) + result += ",\n"+indent+"num_queues, queues, num_wait_events, wait_events, events);" + result += "\n}\n" return result # The wrapper to the reference CBLAS routines (for performance/correctness testing) @@ -252,44 +254,45 @@ def wrapper_cblas(routines): result = "" for routine in routines: if routine.has_tests: - result += "\n// Forwards the Netlib BLAS calls for %s\n" % (routine.ShortNames()) + result += "\n// Forwards the Netlib BLAS calls for %s\n" % (routine.ShortNamesTested()) for flavour in routine.flavours: - indent = " "*(10 + routine.Length()) - result += routine.RoutineHeaderWrapperC(flavour, False, 12)+" {\n" - arguments = routine.ArgumentsWrapperC(flavour) - - # Double-precision scalars - for scalar in routine.scalars: - if flavour.IsComplex(scalar): - result += " const auto "+scalar+"_array = std::vector<"+flavour.buffertype[:-1]+">{"+scalar+".real(), "+scalar+".imag()};\n" - - # Special case for scalar outputs - assignment = "" - postfix = "" - endofline = "" - extra_argument = "" - for output_buffer in routine.outputs: - if output_buffer in routine.ScalarBuffersFirst(): - if flavour in [C,Z]: - postfix += "_sub" - indent += " " - extra_argument += ",\n"+indent+"reinterpret_cast<return_pointer_"+flavour.buffertype[:-1]+">(&"+output_buffer+"_buffer["+output_buffer+"_offset])" - elif output_buffer in routine.IndexBuffers(): - assignment = "((int*)&"+output_buffer+"_buffer[0])["+output_buffer+"_offset] = " - indent += " "*len(assignment) - else: - assignment = output_buffer+"_buffer["+output_buffer+"_offset]" - if (flavour.name in ["Sc","Dz"]): - assignment = assignment+".real(" - endofline += ")" + if flavour.precision_name in ["S","D","C","Z"]: + indent = " "*(10 + routine.Length()) + result += routine.RoutineHeaderWrapperC(flavour, False, 12)+" {\n" + arguments = routine.ArgumentsWrapperC(flavour) + + # Double-precision scalars + for scalar in routine.scalars: + if flavour.IsComplex(scalar): + result += " const auto "+scalar+"_array = std::vector<"+flavour.buffertype[:-1]+">{"+scalar+".real(), "+scalar+".imag()};\n" + + # Special case for scalar outputs + assignment = "" + postfix = "" + endofline = "" + extra_argument = "" + for output_buffer in routine.outputs: + if output_buffer in routine.ScalarBuffersFirst(): + if flavour in [C,Z]: + postfix += "_sub" + indent += " " + extra_argument += ",\n"+indent+"reinterpret_cast<return_pointer_"+flavour.buffertype[:-1]+">(&"+output_buffer+"_buffer["+output_buffer+"_offset])" + elif output_buffer in routine.IndexBuffers(): + assignment = "((int*)&"+output_buffer+"_buffer[0])["+output_buffer+"_offset] = " + indent += " "*len(assignment) else: - assignment = assignment+" = " - indent += " "*len(assignment) - - result += " "+assignment+"cblas_"+flavour.name.lower()+routine.name+postfix+"(" - result += (",\n"+indent).join([a for a in arguments]) - result += extra_argument+endofline+");" - result += "\n}\n" + assignment = output_buffer+"_buffer["+output_buffer+"_offset]" + if (flavour.name in ["Sc","Dz"]): + assignment = assignment+".real(" + endofline += ")" + else: + assignment = assignment+" = " + indent += " "*len(assignment) + + result += " "+assignment+"cblas_"+flavour.name.lower()+routine.name+postfix+"(" + result += (",\n"+indent).join([a for a in arguments]) + result += extra_argument+endofline+");" + result += "\n}\n" return result # ================================================================================================== @@ -368,9 +371,10 @@ for level in [1,2,3]: body += "int main(int argc, char *argv[]) {\n" not_first = "false" for flavour in routine.flavours: - body += " clblast::RunTests<clblast::TestX"+routine.name+flavour.TestTemplate() - body += ">(argc, argv, "+not_first+", \""+flavour.name+routine.name.upper()+"\");\n" - not_first = "true" + if flavour.precision_name in ["S","D","C","Z"]: + body += " clblast::RunTests<clblast::TestX"+routine.name+flavour.TestTemplate() + body += ">(argc, argv, "+not_first+", \""+flavour.name+routine.name.upper()+"\");\n" + not_first = "true" body += " return 0;\n" body += "}\n" f.write(header+"\n") @@ -397,7 +401,7 @@ for level in [1,2,3]: body += " case clblast::Precision::k"+PrecisionToFullName(precision)+":" found = False for flavour in routine.flavours: - if flavour.precision_name == precision: + if flavour.precision_name == precision and flavour.precision_name in ["S","D","C","Z"]: body += "\n clblast::RunClient<clblast::TestX"+routine.name+flavour.TestTemplate() body += ">(argc, argv); break;\n" found = True diff --git a/scripts/generator/routine.py b/scripts/generator/routine.py index e5059c61..fbf5836a 100644 --- a/scripts/generator/routine.py +++ b/scripts/generator/routine.py @@ -119,6 +119,12 @@ class Routine(): def ShortNames(self): return "/".join([f.name+self.name.upper() for f in self.flavours]) + # As above, but excludes some + def ShortNamesTested(self): + names = [f.name+self.name.upper() for f in self.flavours] + if "H"+self.name.upper() in names: names.remove("H"+self.name.upper()) + return "/".join(names) + # Determines which buffers go first (between alpha and beta) and which ones go after def BuffersFirst(self): if self.level == "2b": |