summaryrefslogtreecommitdiff
path: root/scripts/generator
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2016-05-26 23:36:19 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2016-05-26 23:36:19 +0200
commit03182f9d07533f795a498936391da744d982e8e2 (patch)
tree3a73046809927abd1000fe3309f37787d1791976 /scripts/generator
parentb487d4dd44179293c9e08ddf2ce3ed902fa749c8 (diff)
Added half-precision tests for the clBLAS reference through conversion to single-precision
Diffstat (limited to 'scripts/generator')
-rw-r--r--scripts/generator/generator.py29
-rw-r--r--scripts/generator/routine.py20
2 files changed, 40 insertions, 9 deletions
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 3d07c5a3..f5fc5ecf 100644
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -235,9 +235,11 @@ def wrapper_clblas(routines):
if routine.NoScalars():
result += routine.RoutineHeaderWrapperCL(routine.template, True, 21)+";\n"
for flavour in routine.flavours:
- indent = " "*(17 + routine.Length())
result += routine.RoutineHeaderWrapperCL(flavour, False, 21)+" {\n"
+
+ # There is a version available in clBLAS
if flavour.precision_name in ["S","D","C","Z"]:
+ indent = " "*(17 + routine.Length())
arguments = routine.ArgumentsWrapperCL(flavour)
if routine.scratch:
result += " auto queue = Queue(queues[0]);\n"
@@ -247,8 +249,27 @@ def wrapper_clblas(routines):
result += " return clblas"+flavour.name+routine.name+"("
result += (",\n"+indent).join([a for a in arguments])
result += ",\n"+indent+"num_queues, queues, num_wait_events, wait_events, events);"
- else:
- result += " return clblasNotImplemented;"
+
+ # There is no clBLAS available, forward the call to one of the available functions
+ else: # Half-precision
+ indent = " "*(24 + routine.Length())
+
+ # Convert to float (note: also integer buffers are stored as half/float)
+ for buf in routine.inputs + routine.outputs:
+ result += " auto "+buf+"_buffer_bis = HalfToFloatBuffer("+buf+"_buffer, queues[0]);\n"
+
+ # Call the float routine
+ result += " auto status = clblasX"+routine.name+"("
+ result += (",\n"+indent).join([a for a in routine.ArgumentsHalf()])
+ result += ",\n"+indent+"num_queues, queues, num_wait_events, wait_events, events);"
+ result += "\n"
+
+ # Convert back to half
+ for buf in routine.outputs:
+ result += " FloatToHalfBuffer("+buf+"_buffer, "+buf+"_buffer_bis, queues[0]);\n"
+ result += " return status;"
+
+ # Complete
result += "\n}\n"
return result
@@ -336,7 +357,7 @@ files = [
path_clblast+"/test/wrapper_clblas.h",
path_clblast+"/test/wrapper_cblas.h",
]
-header_lines = [84, 71, 93, 22, 29, 51]
+header_lines = [84, 71, 93, 22, 29, 41]
footer_lines = [17, 71, 19, 14, 6, 6]
# Checks whether the command-line arguments are valid; exists otherwise
diff --git a/scripts/generator/routine.py b/scripts/generator/routine.py
index a347de0e..fe857ea8 100644
--- a/scripts/generator/routine.py
+++ b/scripts/generator/routine.py
@@ -185,6 +185,16 @@ class Routine():
return [", ".join(a+b+c)]
return []
+ # As above but with data-types
+ def BufferDefWrapperCL(self, name, flavour):
+ prefix = "const " if (name in self.inputs) else ""
+ if (name in self.inputs) or (name in self.outputs):
+ a = [prefix+"Buffer<"+flavour.buffertype+">& "+name+"_buffer"]
+ b = ["const size_t "+name+"_offset"]
+ c = ["const size_t "+name+"_"+self.Postfix(name)] if (name not in self.BuffersWithoutLdInc()) else []
+ return [", ".join(a+b+c)]
+ return []
+
# As above but as vectors
def BufferDefVector(self, name, flavour):
prefix = "const " if (name in self.inputs) else ""
@@ -208,7 +218,7 @@ class Routine():
# As above but with a static cast for clBLAS wrapper
def BufferWrapperCL(self, name):
if (name in self.inputs) or (name in self.outputs):
- a = [name+"_buffer"]
+ a = [name+"_buffer()"]
b = [name+"_offset"]
c = []
if (name in ["x","y"]):
@@ -491,12 +501,12 @@ class Routine():
# As above, but clBLAS wrapper plain datatypes
def ArgumentsDefWrapperCL(self, flavour):
return (self.OptionsDefWrapperCL() + self.SizesDef() +
- list(chain(*[self.BufferDef(b) for b in self.ScalarBuffersFirst()])) +
+ list(chain(*[self.BufferDefWrapperCL(b, flavour) for b in self.ScalarBuffersFirst()])) +
self.ScalarDefPlain("alpha", flavour) +
- list(chain(*[self.BufferDef(b) for b in self.BuffersFirst()])) +
+ list(chain(*[self.BufferDefWrapperCL(b, flavour) for b in self.BuffersFirst()])) +
self.ScalarDefPlain("beta", flavour) +
- list(chain(*[self.BufferDef(b) for b in self.BuffersSecond()])) +
- list(chain(*[self.BufferDef(b) for b in self.ScalarBuffersSecond()])) +
+ list(chain(*[self.BufferDefWrapperCL(b, flavour) for b in self.BuffersSecond()])) +
+ list(chain(*[self.BufferDefWrapperCL(b, flavour) for b in self.ScalarBuffersSecond()])) +
list(chain(*[self.ScalarDefPlain(s, flavour) for s in self.OtherScalars()])))
# As above, but CBLAS wrapper plain datatypes