diff options
Diffstat (limited to 'scripts/generator/generator/cpp.py')
-rw-r--r-- | scripts/generator/generator/cpp.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py index 3631e737..51ca047c 100644 --- a/scripts/generator/generator/cpp.py +++ b/scripts/generator/generator/cpp.py @@ -58,7 +58,7 @@ def clblast_cc(routine, cuda=False): result += " auto queue_cpp = Queue(*queue);" + NL event = "nullptr" if cuda else "event" result += " auto routine = X" + routine.plain_name() + "<" + routine.template.template + ">(queue_cpp, " + event + ");" + NL - if routine.batched: + if routine.batched == 1: result += " " + (NL + " ").join(routine.batched_transform_to_cpp()) + NL if routine.temp_buffer: null = "0" if cuda else "nullptr" @@ -110,7 +110,7 @@ def clblast_c_cc(routine): template = "<" + flavour.template + ">" if routine.no_scalars() else "" indent = " " * (16 + routine.length() + len(template)) result += routine.routine_header_c(flavour, 27, "") + " {" + NL - if routine.batched: + if routine.batched == 1: result += " " + (NL + " ").join(routine.batched_transform_to_complex(flavour)) + NL result += " try {" + NL result += " return static_cast<CLBlastStatusCode>(" + NL @@ -388,7 +388,7 @@ def performance_test(routine, level_string): found = False for flavour in routine.flavours: if flavour.precision_name == precision: - extra_template_argument = "0, " if routine.name == "gemm" and not routine.batched else "" + extra_template_argument = "0, " if routine.name == "gemm" and routine.batched == 0 else "" result += NL + " clblast::RunClient<clblast::TestX" + routine.plain_name() result += flavour.test_template(extra_template_argument) result += ">(argc, argv); break;" + NL @@ -410,7 +410,7 @@ def correctness_test(routine, level_string): result += "int main(int argc, char *argv[]) {" + NL result += " auto errors = size_t{0};" + NL not_first = "false" - extra_template_arguments = ["1, ", "2, "] if routine.name == "gemm" and not routine.batched else [""] + extra_template_arguments = ["1, ", "2, "] if routine.name == "gemm" and routine.batched == 0 else [""] for extra_template_argument in extra_template_arguments: for flavour in routine.flavours: result += " errors += clblast::RunTests<clblast::TestX" + routine.plain_name() |