summaryrefslogtreecommitdiff
path: root/scripts/generator/generator/cpp.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/generator/generator/cpp.py')
-rw-r--r--scripts/generator/generator/cpp.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py
index 3631e737..51ca047c 100644
--- a/scripts/generator/generator/cpp.py
+++ b/scripts/generator/generator/cpp.py
@@ -58,7 +58,7 @@ def clblast_cc(routine, cuda=False):
result += " auto queue_cpp = Queue(*queue);" + NL
event = "nullptr" if cuda else "event"
result += " auto routine = X" + routine.plain_name() + "<" + routine.template.template + ">(queue_cpp, " + event + ");" + NL
- if routine.batched:
+ if routine.batched == 1:
result += " " + (NL + " ").join(routine.batched_transform_to_cpp()) + NL
if routine.temp_buffer:
null = "0" if cuda else "nullptr"
@@ -110,7 +110,7 @@ def clblast_c_cc(routine):
template = "<" + flavour.template + ">" if routine.no_scalars() else ""
indent = " " * (16 + routine.length() + len(template))
result += routine.routine_header_c(flavour, 27, "") + " {" + NL
- if routine.batched:
+ if routine.batched == 1:
result += " " + (NL + " ").join(routine.batched_transform_to_complex(flavour)) + NL
result += " try {" + NL
result += " return static_cast<CLBlastStatusCode>(" + NL
@@ -388,7 +388,7 @@ def performance_test(routine, level_string):
found = False
for flavour in routine.flavours:
if flavour.precision_name == precision:
- extra_template_argument = "0, " if routine.name == "gemm" and not routine.batched else ""
+ extra_template_argument = "0, " if routine.name == "gemm" and routine.batched == 0 else ""
result += NL + " clblast::RunClient<clblast::TestX" + routine.plain_name()
result += flavour.test_template(extra_template_argument)
result += ">(argc, argv); break;" + NL
@@ -410,7 +410,7 @@ def correctness_test(routine, level_string):
result += "int main(int argc, char *argv[]) {" + NL
result += " auto errors = size_t{0};" + NL
not_first = "false"
- extra_template_arguments = ["1, ", "2, "] if routine.name == "gemm" and not routine.batched else [""]
+ extra_template_arguments = ["1, ", "2, "] if routine.name == "gemm" and routine.batched == 0 else [""]
for extra_template_argument in extra_template_arguments:
for flavour in routine.flavours:
result += " errors += clblast::RunTests<clblast::TestX" + routine.plain_name()