diff options
Diffstat (limited to 'scripts/generator/generator/cpp.py')
-rw-r--r-- | scripts/generator/generator/cpp.py | 20 |
1 file changed, 12 insertions, 8 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py
index 6dc3fc93..e32738ee 100644
--- a/scripts/generator/generator/cpp.py
+++ b/scripts/generator/generator/cpp.py
@@ -226,7 +226,10 @@ def wrapper_clblas(routine):
 # Convert to float (note: also integer buffers are stored as half/float)
 for buf in routine.inputs + routine.outputs:
- result += " auto " + buf + "_buffer_bis = HalfToFloatBuffer(" + buf + "_buffer, queues[0]);" + NL
+ if buf not in routine.index_buffers():
+ result += " auto " + buf + "_buffer_bis = HalfToFloatBuffer(" + buf + "_buffer, queues[0]);" + NL
+ else:
+ result += " auto " + buf + "_buffer_bis = " + buf + "_buffer;" + NL
 # Call the float routine
 result += " auto status = clblasX" + routine.name + "("
@@ -236,7 +239,8 @@ def wrapper_clblas(routine):
 # Convert back to half
 for buf in routine.outputs:
- result += " FloatToHalfBuffer(" + buf + "_buffer, " + buf + "_buffer_bis, queues[0]);" + NL
+ if buf not in routine.index_buffers():
+ result += " FloatToHalfBuffer(" + buf + "_buffer, " + buf + "_buffer_bis, queues[0]);" + NL
 result += " return status;"
 # Complete
@@ -276,10 +280,6 @@ def wrapper_cblas(routine):
 extra_argument += "," + NL + indent
 extra_argument += "reinterpret_cast<return_pointer_" + flavour.buffer_type[:-1] + ">"
 extra_argument += "(&" + output_buffer + "_buffer[" + output_buffer + "_offset])"
- elif output_buffer in routine.index_buffers():
- assignment = "reinterpret_cast<int*>(&" + output_buffer + "_buffer[0])[" + output_buffer + "_offset] = static_cast<int>("
- postpostfix = ")"
- indent += " " * (len(assignment) + 1)
 else:
 assignment = output_buffer + "_buffer[" + output_buffer + "_offset]"
 if flavour.name in ["Sc", "Dz"]:
@@ -299,7 +299,10 @@ def wrapper_cblas(routine):
 # Convert to float (note: also integer buffers are stored as half/float)
 for buf in routine.inputs + routine.outputs:
- result += " auto " + buf + "_buffer_bis = HalfToFloatBuffer(" + buf + "_buffer);" + NL
+ if buf not in routine.index_buffers():
+ result += " auto " + buf + "_buffer_bis = HalfToFloatBuffer(" + buf + "_buffer);" + NL
+ else:
+ result += " auto " + buf + "_buffer_bis = " + buf + "_buffer;" + NL
 # Call the float routine
 result += " cblasX" + routine.name + "("
@@ -308,7 +311,8 @@ def wrapper_cblas(routine):
 # Convert back to half
 for buf in routine.outputs:
- result += " FloatToHalfBuffer(" + buf + "_buffer, " + buf + "_buffer_bis);" + NL
+ if buf not in routine.index_buffers():
+ result += " FloatToHalfBuffer(" + buf + "_buffer, " + buf + "_buffer_bis);" + NL
 # Complete
 result += "}" + NL |