summaryrefslogtreecommitdiff
path: root/scripts/generator/generator/cpp.py
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2016-11-22 08:41:52 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2016-11-22 08:41:52 +0100
commit26ca07148092b5d4fcb0e25190e07bf6acae25a3 (patch)
tree81854684aa03e09ad471228024f6c986b1b60f55 /scripts/generator/generator/cpp.py
parenteefe0df43575686c6aa48a9fb6e25e27bef1af40 (diff)
Minor changes to ensure full compatibility with the Netlib CBLAS API
Diffstat (limited to 'scripts/generator/generator/cpp.py')
-rw-r--r--scripts/generator/generator/cpp.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py
index 7b7ece22..6bb3080f 100644
--- a/scripts/generator/generator/cpp.py
+++ b/scripts/generator/generator/cpp.py
@@ -112,6 +112,7 @@ def clblast_netlib_c_cc(routine):
# There is a version available in CBLAS
if flavour.precision_name in ["S", "D", "C", "Z"]:
template = "<" + flavour.template + ">" if routine.no_scalars() else ""
+ name_postfix = "_sub" if routine.name in routine.routines_scalar_no_return() else ""
indent = " " * (21 + routine.length() + len(template))
result += routine.routine_header_netlib(flavour, 9, "") + " {" + NL
@@ -129,6 +130,8 @@ def clblast_netlib_c_cc(routine):
for i, name in enumerate(routine.inputs + routine.outputs):
buffer_type = routine.get_buffer_type(name, flavour)
result += " " + routine.create_buffer(name, buffer_type) + NL
+ if name in routine.scalar_buffers_second_non_pointer():
+ result += " " + buffer_type + " " + name + "_vec[1]; " + name + "_vec[0] = " + name + ";" + NL
for name in routine.inputs + routine.outputs:
if name not in routine.scalar_buffers_first():
prefix = "" if name in routine.outputs else "const "
@@ -148,14 +151,14 @@ def clblast_netlib_c_cc(routine):
# Copy back and clean-up
for name in routine.outputs:
- if name in routine.scalar_buffers_first():
+ if name in routine.scalar_buffers_first() and routine.name not in routine.routines_scalar_no_return():
buffer_type = routine.get_buffer_type(name, flavour)
result += " " + buffer_type + " " + name + "[" + name + "_size];" + NL
for name in routine.outputs:
buffer_type = routine.get_buffer_type(name, flavour)
result += " " + routine.read_buffer(name, buffer_type) + NL
for name in routine.outputs:
- if name in routine.scalar_buffers_first():
+ if name in routine.scalar_buffers_first() and routine.name not in routine.routines_scalar_no_return():
result += " return " + name + "[0]"
if flavour.buffer_type in ["float2", "double2"]:
if name not in routine.index_buffers():