summaryrefslogtreecommitdiff
path: root/scripts/generator/generator/routine.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/generator/generator/routine.py')
-rw-r--r--scripts/generator/generator/routine.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index 8b6ab57f..c2201c0d 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -282,7 +282,10 @@ class Routine:
"""As above but for OpenCL"""
prefix = "const " if name in self.inputs else ""
if name in self.inputs or name in self.outputs:
- a = [prefix + "Buffer<" + flavour.buffer_type + ">& " + name + "_buffer"]
+ if name == "imax":
+ a = [prefix + "Buffer<unsigned int>& " + name + "_buffer"]
+ else:
+ a = [prefix + "Buffer<" + flavour.buffer_type + ">& " + name + "_buffer"]
b = ["const size_t " + name + "_offset"]
c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
return [", ".join(a + b + c)]
@@ -292,7 +295,10 @@ class Routine:
"""As above but for CUDA"""
prefix = "const " if name in self.inputs else ""
if name in self.inputs or name in self.outputs:
- a = [prefix + flavour.buffer_type + "* " + name + "_buffer"]
+ if name == "imax":
+ a = [prefix + "unsigned int * " + name + "_buffer"]
+ else:
+ a = [prefix + flavour.buffer_type + "* " + name + "_buffer"]
b = ["const size_t " + name + "_offset"]
c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
return [", ".join(a + b + c)]
@@ -302,7 +308,10 @@ class Routine:
"""As above but as vectors"""
prefix = "const " if name in self.inputs else ""
if name in self.inputs or name in self.outputs:
- a = [prefix + "std::vector<" + flavour.buffer_type + ">& " + name + "_buffer"]
+ if name == "imax":
+ a = [prefix + "std::vector<unsigned int>& " + name + "_buffer"]
+ else:
+ a = [prefix + "std::vector<" + flavour.buffer_type + ">& " + name + "_buffer"]
b = ["const size_t " + name + "_offset"]
c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
return [", ".join(a + b + c)]