summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-01-04 19:33:51 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2018-01-04 19:33:51 +0100
commit44431daecc63cc4ead3208327bcd70834b3f4bdb (patch)
tree48aaac483856bbe6aa5ce54166d7adb8897e1aa5 /scripts
parentaf14fff1e9f93daa535b673ad1391fac397b5edc (diff)
Added a CUDA version of the GEMM temp-buffer optional argument
Diffstat (limited to 'scripts')
-rw-r--r--scripts/generator/generator/cpp.py6
-rw-r--r--scripts/generator/generator/routine.py2
2 files changed, 5 insertions, 3 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py
index a850a032..656253d7 100644
--- a/scripts/generator/generator/cpp.py
+++ b/scripts/generator/generator/cpp.py
@@ -60,12 +60,12 @@ def clblast_cc(routine, cuda=False):
result += " auto routine = X" + routine.plain_name() + "<" + routine.template.template + ">(queue_cpp, " + event + ");" + NL
if routine.batched:
result += " " + (NL + " ").join(routine.batched_transform_to_cpp()) + NL
- if routine.temp_buffer and not cuda:
+ if routine.temp_buffer:
result += " const auto temp_buffer_provided = temp_buffer != nullptr;\n"
result += " auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(nullptr);\n"
result += " routine.Do" + routine.capitalized_name() + "("
result += ("," + NL + indent1).join([a for a in routine.arguments_clcudaapi()])
- if routine.temp_buffer and not cuda:
+ if routine.temp_buffer:
result += ",\n" + indent1 + "temp_buffer_cpp, temp_buffer_provided"
result += ");" + NL
result += " return StatusCode::kSuccess;" + NL
@@ -84,6 +84,8 @@ def clblast_cc(routine, cuda=False):
result += "," + NL + indent2
if cuda:
result += "const CUcontext, const CUdevice"
+ if routine.temp_buffer:
+ result += ", CUdeviceptr"
else:
result += "cl_command_queue*, cl_event*"
if routine.temp_buffer:
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index 11b9080f..22be02b0 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -819,7 +819,7 @@ class Routine:
result += "const CUcontext context, const CUdevice device"
else:
result += "cl_command_queue* queue, cl_event* event" + default_event
- if self.temp_buffer and not cuda:
+ if self.temp_buffer:
result += ",\n" + indent + mem_type + " temp_buffer"
if not implementation:
result += " = nullptr"