summary | refs | log | tree | commit | diff
path: root/scripts/generator/generator
diff options
context:
space:
mode:
author Cedric Nugteren <web@cedricnugteren.nl> 2016-11-22 08:41:52 +0100
committer Cedric Nugteren <web@cedricnugteren.nl> 2016-11-22 08:41:52 +0100
commit 26ca07148092b5d4fcb0e25190e07bf6acae25a3 (patch)
tree 81854684aa03e09ad471228024f6c986b1b60f55 /scripts/generator/generator
parent eefe0df43575686c6aa48a9fb6e25e27bef1af40 (diff)
Minor changes to ensure full compatibility with the Netlib CBLAS API
Diffstat (limited to 'scripts/generator/generator')
-rw-r--r-- scripts/generator/generator/cpp.py 7
-rw-r--r-- scripts/generator/generator/routine.py 33
2 files changed, 31 insertions, 9 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py
index 7b7ece22..6bb3080f 100644
--- a/scripts/generator/generator/cpp.py
+++ b/scripts/generator/generator/cpp.py
@@ -112,6 +112,7 @@ def clblast_netlib_c_cc(routine):
# There is a version available in CBLAS
if flavour.precision_name in ["S", "D", "C", "Z"]:
template = "<" + flavour.template + ">" if routine.no_scalars() else ""
+ name_postfix = "_sub" if routine.name in routine.routines_scalar_no_return() else ""
indent = " " * (21 + routine.length() + len(template))
result += routine.routine_header_netlib(flavour, 9, "") + " {" + NL
@@ -129,6 +130,8 @@ def clblast_netlib_c_cc(routine):
for i, name in enumerate(routine.inputs + routine.outputs):
buffer_type = routine.get_buffer_type(name, flavour)
result += " " + routine.create_buffer(name, buffer_type) + NL
+ if name in routine.scalar_buffers_second_non_pointer():
+ result += " " + buffer_type + " " + name + "_vec[1]; " + name + "_vec[0] = " + name + ";" + NL
for name in routine.inputs + routine.outputs:
if name not in routine.scalar_buffers_first():
prefix = "" if name in routine.outputs else "const "
@@ -148,14 +151,14 @@ def clblast_netlib_c_cc(routine):
# Copy back and clean-up
for name in routine.outputs:
- if name in routine.scalar_buffers_first():
+ if name in routine.scalar_buffers_first() and routine.name not in routine.routines_scalar_no_return():
buffer_type = routine.get_buffer_type(name, flavour)
result += " " + buffer_type + " " + name + "[" + name + "_size];" + NL
for name in routine.outputs:
buffer_type = routine.get_buffer_type(name, flavour)
result += " " + routine.read_buffer(name, buffer_type) + NL
for name in routine.outputs:
- if name in routine.scalar_buffers_first():
+ if name in routine.scalar_buffers_first() and routine.name not in routine.routines_scalar_no_return():
result += " return " + name + "[0]"
if flavour.buffer_type in ["float2", "double2"]:
if name not in routine.index_buffers():
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index 391cf3e0..6fcce23b 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -43,6 +43,11 @@ class Routine:
return ["sa", "sb", "sc", "ss", "sd1", "sd2", "sx1", "sy1", "sparam"]
@staticmethod
+ def scalar_buffers_second_non_pointer():
+ """As above, but these ones are not passed as pointers but as scalars instead"""
+ return ["sy1"]
+
+ @staticmethod
def other_scalars():
"""List of scalars other than alpha and beta"""
return ["cos", "sin"]
@@ -68,6 +73,10 @@ class Routine:
return ["a", "b", "c", "ap"]
@staticmethod
+ def routines_scalar_no_return():
+ return ["dotu", "dotc"]
+
+ @staticmethod
def set_size(name, size):
"""Sets the size of a buffer"""
return "const auto " + name + "_size = " + size + ";"
@@ -77,10 +86,12 @@ class Routine:
"""Creates a new CLCudaAPI buffer"""
return "auto " + name + "_buffer = clblast::Buffer<" + template + ">(context, " + name + "_size);"
- @staticmethod
- def write_buffer(name, template):
+ def write_buffer(self, name, template):
"""Writes to a CLCudaAPI buffer"""
- data_structure = "reinterpret_cast<" + template + "*>(" + name + ")"
+ postfix = ""
+ if name in self.scalar_buffers_second_non_pointer():
+ postfix = "_vec"
+ data_structure = "reinterpret_cast<" + template + "*>(" + name + postfix + ")"
return name + "_buffer.Write(queue, " + name + "_size, " + data_structure + ");"
@staticmethod
@@ -206,7 +217,8 @@ class Routine:
prefix = "const " if name in self.inputs else ""
if name in self.inputs or name in self.outputs:
data_type = "void" if flavour.is_non_standard() else flavour.buffer_type
- a = [prefix + data_type + "* " + name + ""]
+ pointer = "" if name in self.scalar_buffers_second_non_pointer() else "*"
+ a = [prefix + data_type + pointer + " " + name + ""]
c = ["const int " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
return [", ".join(a + c)]
return []
@@ -553,13 +565,16 @@ class Routine:
def arguments_def_netlib(self, flavour):
"""As above, but for the Netlib CBLAS API"""
- return (self.options_def_c() + self.sizes_def_netlib() +
+ result=(self.options_def_c() + self.sizes_def_netlib() +
self.scalar_def_void("alpha", flavour) +
list(chain(*[self.buffer_def_pointer(b, flavour) for b in self.buffers_first()])) +
self.scalar_def_void("beta", flavour) +
list(chain(*[self.buffer_def_pointer(b, flavour) for b in self.buffers_second()])) +
list(chain(*[self.buffer_def_pointer(b, flavour) for b in self.scalar_buffers_second()])) +
list(chain(*[self.scalar_def(s, flavour) for s in self.other_scalars()])))
+ if self.name in self.routines_scalar_no_return():
+ result += list(chain(*[self.buffer_def_pointer(b, flavour) for b in self.scalar_buffers_first()]))
+ return result
def arguments_def_c(self, flavour):
"""As above, but for the C API"""
@@ -654,11 +669,15 @@ class Routine:
if output in self.index_buffers():
return_type = "int"
break
- if output in self.scalar_buffers_first():
+ if output in self.scalar_buffers_first() and self.name not in self.routines_scalar_no_return():
return_type = flavour.buffer_type.replace("2", "")
break
indent = " " * (spaces + len(return_type) + self.length())
- result = return_type + extra_qualifier + " cblas_" + flavour.name.lower() + self.name + "("
+ routine_name = self.name
+ if self.name in self.routines_scalar_no_return():
+ routine_name += "_sub"
+ indent += " "
+ result = return_type + extra_qualifier + " cblas_" + flavour.name.lower() + routine_name + "("
result += (",\n" + indent).join([a for a in self.arguments_def_netlib(flavour)]) + ")"
return result