diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2016-11-22 08:41:52 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2016-11-22 08:41:52 +0100 |
commit | 26ca07148092b5d4fcb0e25190e07bf6acae25a3 (patch) | |
tree | 81854684aa03e09ad471228024f6c986b1b60f55 /scripts/generator/generator/cpp.py | |
parent | eefe0df43575686c6aa48a9fb6e25e27bef1af40 (diff) |
Minor changes to ensure full compatibility with the Netlib CBLAS API
Diffstat (limited to 'scripts/generator/generator/cpp.py')
-rw-r--r-- | scripts/generator/generator/cpp.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py index 7b7ece22..6bb3080f 100644 --- a/scripts/generator/generator/cpp.py +++ b/scripts/generator/generator/cpp.py @@ -112,6 +112,7 @@ def clblast_netlib_c_cc(routine): # There is a version available in CBLAS if flavour.precision_name in ["S", "D", "C", "Z"]: template = "<" + flavour.template + ">" if routine.no_scalars() else "" + name_postfix = "_sub" if routine.name in routine.routines_scalar_no_return() else "" indent = " " * (21 + routine.length() + len(template)) result += routine.routine_header_netlib(flavour, 9, "") + " {" + NL @@ -129,6 +130,8 @@ def clblast_netlib_c_cc(routine): for i, name in enumerate(routine.inputs + routine.outputs): buffer_type = routine.get_buffer_type(name, flavour) result += " " + routine.create_buffer(name, buffer_type) + NL + if name in routine.scalar_buffers_second_non_pointer(): + result += " " + buffer_type + " " + name + "_vec[1]; " + name + "_vec[0] = " + name + ";" + NL for name in routine.inputs + routine.outputs: if name not in routine.scalar_buffers_first(): prefix = "" if name in routine.outputs else "const " @@ -148,14 +151,14 @@ def clblast_netlib_c_cc(routine): # Copy back and clean-up for name in routine.outputs: - if name in routine.scalar_buffers_first(): + if name in routine.scalar_buffers_first() and routine.name not in routine.routines_scalar_no_return(): buffer_type = routine.get_buffer_type(name, flavour) result += " " + buffer_type + " " + name + "[" + name + "_size];" + NL for name in routine.outputs: buffer_type = routine.get_buffer_type(name, flavour) result += " " + routine.read_buffer(name, buffer_type) + NL for name in routine.outputs: - if name in routine.scalar_buffers_first(): + if name in routine.scalar_buffers_first() and routine.name not in routine.routines_scalar_no_return(): result += " return " + name + "[0]" if flavour.buffer_type in ["float2", "double2"]: if name not in routine.index_buffers(): |