summaryrefslogtreecommitdiff
path: root/scripts/generator
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-10-12 12:20:43 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2017-10-12 12:20:43 +0200
commitcc5b4754250b3c03b9b0f8d72f32d1eacac15b18 (patch)
tree747a5ad136f708de3559c061243e5f31bc17977a /scripts/generator
parentb901809345848b44442c787380b13db5e5156df0 (diff)
CUDA API now takes context and device in instead of stream
Diffstat (limited to 'scripts/generator')
-rw-r--r--scripts/generator/generator/cpp.py9
-rw-r--r--scripts/generator/generator/routine.py4
2 files changed, 9 insertions, 4 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py
index f1ee1959..5413906a 100644
--- a/scripts/generator/generator/cpp.py
+++ b/scripts/generator/generator/cpp.py
@@ -50,7 +50,12 @@ def clblast_cc(routine, cuda=False):
if routine.implemented:
result += routine.routine_header_cpp(12, "", cuda) + " {" + NL
result += " try {" + NL
- result += " auto queue_cpp = Queue(*queue);" + NL
+ if cuda:
+ result += " const auto context_cpp = Context(context);" + NL
+ result += " const auto device_cpp = Device(device);" + NL
+ result += " auto queue_cpp = Queue(context_cpp, device_cpp);" + NL
+ else:
+ result += " auto queue_cpp = Queue(*queue);" + NL
result += " auto routine = X" + routine.plain_name() + "<" + routine.template.template + ">(queue_cpp, event);" + NL
if routine.batched:
result += " " + (NL + " ").join(routine.batched_transform_to_cpp()) + NL
@@ -72,7 +77,7 @@ def clblast_cc(routine, cuda=False):
result += ("," + NL + indent2).join([a for a in arguments])
result += "," + NL + indent2
if cuda:
- result += "CUstream*"
+ result += "const CUcontext, const CUdevice"
else:
result += "cl_command_queue*, cl_event*"
result += ");" + NL
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index c3c1f775..b6b55821 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -813,7 +813,7 @@ class Routine:
result += (",\n" + indent).join([a for a in arguments])
result += ",\n" + indent
if cuda:
- result += "CUstream* stream"
+ result += "const CUcontext context, const CUdevice device"
else:
result += "cl_command_queue* queue, cl_event* event" + default_event
result += ")"
@@ -830,7 +830,7 @@ class Routine:
result += (",\n" + indent).join([a for a in arguments])
result += ",\n" + indent
if cuda:
- result += "CUstream* stream"
+ result += "const CUcontext, const CUdevice"
else:
result += "cl_command_queue*, cl_event*"
result += ")"