summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2016-05-12 19:56:21 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2016-05-12 19:56:21 +0200
commitf2ba75890c522b4fe1762bfeac3e08667cf9588a (patch)
tree82e22cb72fbfb135570ce3bf3234bd1f60c760c1 /scripts
parent1c72d225c53c123ed810cf3f56f5c92603f7f791 (diff)
Initial changes in preparation for half-precision fp16 support
Diffstat (limited to 'scripts')
-rw-r--r--scripts/generator/datatype.py3
-rw-r--r--scripts/generator/generator.py114
-rw-r--r--scripts/generator/routine.py6
3 files changed, 68 insertions, 55 deletions
diff --git a/scripts/generator/datatype.py b/scripts/generator/datatype.py
index 5a58ab53..5bff95d1 100644
--- a/scripts/generator/datatype.py
+++ b/scripts/generator/datatype.py
@@ -13,10 +13,13 @@
# ==================================================================================================
# Short-hands for data-types
+HLF = "half"
FLT = "float"
DBL = "double"
FLT2 = "float2"
DBL2 = "double2"
+
+HCL = "cl_half"
F2CL = "cl_float2"
D2CL = "cl_double2"
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 210f371f..bc8fa783 100644
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -28,11 +28,12 @@ import os.path
# Local files
from routine import Routine
-from datatype import DataType, FLT, DBL, FLT2, DBL2, F2CL, D2CL
+from datatype import DataType, HLF, FLT, DBL, FLT2, DBL2, HCL, F2CL, D2CL
# ==================================================================================================
# Regular data-types
+H = DataType("H", "H", HLF, [HLF, HLF, HCL, HCL], HLF ) # half (16)
S = DataType("S", "S", FLT, [FLT, FLT, FLT, FLT], FLT ) # single (32)
D = DataType("D", "D", DBL, [DBL, DBL, DBL, DBL], DBL ) # double (64)
C = DataType("C", "C", FLT2, [FLT2, FLT2, F2CL, F2CL], FLT2) # single-complex (3232)
@@ -67,7 +68,7 @@ routines = [
Routine(True, True, "1", "swap", T, [S,D,C,Z], ["n"], [], [], ["x","y"], [], "", "Swap two vectors", "Interchanges the contents of vectors x and y.", []),
Routine(True, True, "1", "scal", T, [S,D,C,Z], ["n"], [], [], ["x"], ["alpha"], "", "Vector scaling", "Multiplies all elements of vector x by a scalar constant alpha.", []),
Routine(True, True, "1", "copy", T, [S,D,C,Z], ["n"], [], ["x"], ["y"], [], "", "Vector copy", "Copies the contents of vector x into vector y.", []),
- Routine(True, True, "1", "axpy", T, [S,D,C,Z], ["n"], [], ["x"], ["y"], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation y = alpha * x + y, in which x and y are vectors and alpha is a scalar constant.", []),
+ Routine(True, True, "1", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation y = alpha * x + y, in which x and y are vectors and alpha is a scalar constant.", []),
Routine(True, True, "1", "dot", T, [S,D], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two vectors", "Multiplies the vectors x and y element-wise and accumulates the results. The sum is stored in the dot buffer.", []),
Routine(True, True, "1", "dotu", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two complex vectors", "See the regular xDOT routine.", []),
Routine(True, True, "1", "dotc", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two complex vectors, one conjugated", "See the regular xDOT routine.", []),
@@ -229,22 +230,23 @@ def wrapper_clblas(routines):
result = ""
for routine in routines:
if routine.has_tests:
- result += "\n// Forwards the clBLAS calls for %s\n" % (routine.ShortNames())
+ result += "\n// Forwards the clBLAS calls for %s\n" % (routine.ShortNamesTested())
if routine.NoScalars():
result += routine.RoutineHeaderWrapperCL(routine.template, True, 21)+";\n"
for flavour in routine.flavours:
- indent = " "*(17 + routine.Length())
- result += routine.RoutineHeaderWrapperCL(flavour, False, 21)+" {\n"
- arguments = routine.ArgumentsWrapperCL(flavour)
- if routine.scratch:
- result += " auto queue = Queue(queues[0]);\n"
- result += " auto context = queue.GetContext();\n"
- result += " auto scratch_buffer = Buffer<"+flavour.template+">(context, "+routine.scratch+");\n"
- arguments += ["scratch_buffer()"]
- result += " return clblas"+flavour.name+routine.name+"("
- result += (",\n"+indent).join([a for a in arguments])
- result += ",\n"+indent+"num_queues, queues, num_wait_events, wait_events, events);"
- result += "\n}\n"
+ if flavour.precision_name in ["S","D","C","Z"]:
+ indent = " "*(17 + routine.Length())
+ result += routine.RoutineHeaderWrapperCL(flavour, False, 21)+" {\n"
+ arguments = routine.ArgumentsWrapperCL(flavour)
+ if routine.scratch:
+ result += " auto queue = Queue(queues[0]);\n"
+ result += " auto context = queue.GetContext();\n"
+ result += " auto scratch_buffer = Buffer<"+flavour.template+">(context, "+routine.scratch+");\n"
+ arguments += ["scratch_buffer()"]
+ result += " return clblas"+flavour.name+routine.name+"("
+ result += (",\n"+indent).join([a for a in arguments])
+ result += ",\n"+indent+"num_queues, queues, num_wait_events, wait_events, events);"
+ result += "\n}\n"
return result
# The wrapper to the reference CBLAS routines (for performance/correctness testing)
@@ -252,44 +254,45 @@ def wrapper_cblas(routines):
result = ""
for routine in routines:
if routine.has_tests:
- result += "\n// Forwards the Netlib BLAS calls for %s\n" % (routine.ShortNames())
+ result += "\n// Forwards the Netlib BLAS calls for %s\n" % (routine.ShortNamesTested())
for flavour in routine.flavours:
- indent = " "*(10 + routine.Length())
- result += routine.RoutineHeaderWrapperC(flavour, False, 12)+" {\n"
- arguments = routine.ArgumentsWrapperC(flavour)
-
- # Double-precision scalars
- for scalar in routine.scalars:
- if flavour.IsComplex(scalar):
- result += " const auto "+scalar+"_array = std::vector<"+flavour.buffertype[:-1]+">{"+scalar+".real(), "+scalar+".imag()};\n"
-
- # Special case for scalar outputs
- assignment = ""
- postfix = ""
- endofline = ""
- extra_argument = ""
- for output_buffer in routine.outputs:
- if output_buffer in routine.ScalarBuffersFirst():
- if flavour in [C,Z]:
- postfix += "_sub"
- indent += " "
- extra_argument += ",\n"+indent+"reinterpret_cast<return_pointer_"+flavour.buffertype[:-1]+">(&"+output_buffer+"_buffer["+output_buffer+"_offset])"
- elif output_buffer in routine.IndexBuffers():
- assignment = "((int*)&"+output_buffer+"_buffer[0])["+output_buffer+"_offset] = "
- indent += " "*len(assignment)
- else:
- assignment = output_buffer+"_buffer["+output_buffer+"_offset]"
- if (flavour.name in ["Sc","Dz"]):
- assignment = assignment+".real("
- endofline += ")"
+ if flavour.precision_name in ["S","D","C","Z"]:
+ indent = " "*(10 + routine.Length())
+ result += routine.RoutineHeaderWrapperC(flavour, False, 12)+" {\n"
+ arguments = routine.ArgumentsWrapperC(flavour)
+
+ # Double-precision scalars
+ for scalar in routine.scalars:
+ if flavour.IsComplex(scalar):
+ result += " const auto "+scalar+"_array = std::vector<"+flavour.buffertype[:-1]+">{"+scalar+".real(), "+scalar+".imag()};\n"
+
+ # Special case for scalar outputs
+ assignment = ""
+ postfix = ""
+ endofline = ""
+ extra_argument = ""
+ for output_buffer in routine.outputs:
+ if output_buffer in routine.ScalarBuffersFirst():
+ if flavour in [C,Z]:
+ postfix += "_sub"
+ indent += " "
+ extra_argument += ",\n"+indent+"reinterpret_cast<return_pointer_"+flavour.buffertype[:-1]+">(&"+output_buffer+"_buffer["+output_buffer+"_offset])"
+ elif output_buffer in routine.IndexBuffers():
+ assignment = "((int*)&"+output_buffer+"_buffer[0])["+output_buffer+"_offset] = "
+ indent += " "*len(assignment)
else:
- assignment = assignment+" = "
- indent += " "*len(assignment)
-
- result += " "+assignment+"cblas_"+flavour.name.lower()+routine.name+postfix+"("
- result += (",\n"+indent).join([a for a in arguments])
- result += extra_argument+endofline+");"
- result += "\n}\n"
+ assignment = output_buffer+"_buffer["+output_buffer+"_offset]"
+ if (flavour.name in ["Sc","Dz"]):
+ assignment = assignment+".real("
+ endofline += ")"
+ else:
+ assignment = assignment+" = "
+ indent += " "*len(assignment)
+
+ result += " "+assignment+"cblas_"+flavour.name.lower()+routine.name+postfix+"("
+ result += (",\n"+indent).join([a for a in arguments])
+ result += extra_argument+endofline+");"
+ result += "\n}\n"
return result
# ==================================================================================================
@@ -368,9 +371,10 @@ for level in [1,2,3]:
body += "int main(int argc, char *argv[]) {\n"
not_first = "false"
for flavour in routine.flavours:
- body += " clblast::RunTests<clblast::TestX"+routine.name+flavour.TestTemplate()
- body += ">(argc, argv, "+not_first+", \""+flavour.name+routine.name.upper()+"\");\n"
- not_first = "true"
+ if flavour.precision_name in ["S","D","C","Z"]:
+ body += " clblast::RunTests<clblast::TestX"+routine.name+flavour.TestTemplate()
+ body += ">(argc, argv, "+not_first+", \""+flavour.name+routine.name.upper()+"\");\n"
+ not_first = "true"
body += " return 0;\n"
body += "}\n"
f.write(header+"\n")
@@ -397,7 +401,7 @@ for level in [1,2,3]:
body += " case clblast::Precision::k"+PrecisionToFullName(precision)+":"
found = False
for flavour in routine.flavours:
- if flavour.precision_name == precision:
+ if flavour.precision_name == precision and flavour.precision_name in ["S","D","C","Z"]:
body += "\n clblast::RunClient<clblast::TestX"+routine.name+flavour.TestTemplate()
body += ">(argc, argv); break;\n"
found = True
diff --git a/scripts/generator/routine.py b/scripts/generator/routine.py
index e5059c61..fbf5836a 100644
--- a/scripts/generator/routine.py
+++ b/scripts/generator/routine.py
@@ -119,6 +119,12 @@ class Routine():
def ShortNames(self):
return "/".join([f.name+self.name.upper() for f in self.flavours])
+ # As above, but excludes some
+ def ShortNamesTested(self):
+ names = [f.name+self.name.upper() for f in self.flavours]
+ if "H"+self.name.upper() in names: names.remove("H"+self.name.upper())
+ return "/".join(names)
+
# Determines which buffers go first (between alpha and beta) and which ones go after
def BuffersFirst(self):
if self.level == "2b":