diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-01-04 19:33:51 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2018-01-04 19:33:51 +0100 |
commit | 44431daecc63cc4ead3208327bcd70834b3f4bdb (patch) | |
tree | 48aaac483856bbe6aa5ce54166d7adb8897e1aa5 /scripts | |
parent | af14fff1e9f93daa535b673ad1391fac397b5edc (diff) |
Added a CUDA version of the GEMM temp-buffer optional argument
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/generator/generator/cpp.py | 6 | ||||
-rw-r--r-- | scripts/generator/generator/routine.py | 2 |
2 files changed, 5 insertions, 3 deletions
```diff
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py
index a850a032..656253d7 100644
--- a/scripts/generator/generator/cpp.py
+++ b/scripts/generator/generator/cpp.py
@@ -60,12 +60,12 @@ def clblast_cc(routine, cuda=False):
         result += " auto routine = X" + routine.plain_name() + "<" + routine.template.template + ">(queue_cpp, " + event + ");" + NL
         if routine.batched:
             result += " " + (NL + " ").join(routine.batched_transform_to_cpp()) + NL
-        if routine.temp_buffer and not cuda:
+        if routine.temp_buffer:
             result += " const auto temp_buffer_provided = temp_buffer != nullptr;\n"
             result += " auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(nullptr);\n"
         result += " routine.Do" + routine.capitalized_name() + "("
         result += ("," + NL + indent1).join([a for a in routine.arguments_clcudaapi()])
-        if routine.temp_buffer and not cuda:
+        if routine.temp_buffer:
             result += ",\n" + indent1 + "temp_buffer_cpp, temp_buffer_provided"
         result += ");" + NL
         result += " return StatusCode::kSuccess;" + NL
@@ -84,6 +84,8 @@ def clblast_cc(routine, cuda=False):
         result += "," + NL + indent2
         if cuda:
             result += "const CUcontext, const CUdevice"
+            if routine.temp_buffer:
+                result += ", CUdeviceptr"
         else:
             result += "cl_command_queue*, cl_event*"
             if routine.temp_buffer:
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index 11b9080f..22be02b0 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -819,7 +819,7 @@ class Routine:
             result += "const CUcontext context, const CUdevice device"
         else:
             result += "cl_command_queue* queue, cl_event* event" + default_event
-        if self.temp_buffer and not cuda:
+        if self.temp_buffer:
             result += ",\n" + indent + mem_type + " temp_buffer"
             if not implementation:
                 result += " = nullptr"
```