diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-03-05 10:38:38 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-03-05 10:38:38 +0100 |
commit | f9a520b3aff7b4eec99d9e11a03f9467e7ab351c (patch) | |
tree | 90612ac786448fa6e76681ecf6755f1c35c458a4 /scripts/generator/generator/cpp.py | |
parent | 37228c90988509acef9e8a892a752300b7645210 (diff) |
Prepared generator for batched routines; added batched AXPY routine interface
Diffstat (limited to 'scripts/generator/generator/cpp.py')
-rw-r--r-- | scripts/generator/generator/cpp.py | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py index c14d00a1..91fdf458 100644 --- a/scripts/generator/generator/cpp.py +++ b/scripts/generator/generator/cpp.py @@ -51,8 +51,10 @@ def clblast_cc(routine): result += routine.routine_header_cpp(12, "") + " {" + NL result += " try {" + NL result += " auto queue_cpp = Queue(*queue);" + NL - result += " auto routine = X" + routine.name + "<" + routine.template.template + ">(queue_cpp, event);" + NL - result += " routine.Do" + routine.name.capitalize() + "(" + result += " auto routine = X" + routine.plain_name() + "<" + routine.template.template + ">(queue_cpp, event);" + NL + if routine.batched: + result += " " + (NL + " ").join(routine.batched_transform_to_cpp()) + NL + result += " routine.Do" + routine.capitalized_name() + "(" result += ("," + NL + indent1).join([a for a in routine.arguments_clcudaapi()]) result += ");" + NL result += " return StatusCode::kSuccess;" + NL @@ -63,7 +65,7 @@ def clblast_cc(routine): result += "}" + NL for flavour in routine.flavours: indent2 = " " * (34 + routine.length() + len(flavour.template)) - result += "template StatusCode PUBLIC_API " + routine.name.capitalize() + "<" + flavour.template + ">(" + result += "template StatusCode PUBLIC_API " + routine.capitalized_name() + "<" + flavour.template + ">(" result += ("," + NL + indent2).join([a for a in routine.arguments_type(flavour)]) result += "," + NL + indent2 + "cl_command_queue*, cl_event*);" + NL return result @@ -84,9 +86,11 @@ def clblast_c_cc(routine): template = "<" + flavour.template + ">" if routine.no_scalars() else "" indent = " " * (16 + routine.length() + len(template)) result += routine.routine_header_c(flavour, 27, "") + " {" + NL + if routine.batched: + result += " " + (NL + " ").join(routine.batched_transform_to_complex(flavour)) + NL result += " try {" + NL result += " return static_cast<CLBlastStatusCode>(" + NL - result += " clblast::" + routine.name.capitalize() + template + "(" + result += " clblast::" + routine.capitalized_name() + template + "(" result += ("," + NL + indent).join([a for a in routine.arguments_cast(flavour, indent)]) result += "," + NL + indent + "queue, event)" + NL result += " );" + NL @@ -290,7 +294,7 @@ def performance_test(routine, level_string): """Generates the body of a performance test for a specific routine""" result = "" result += "#include \"test/performance/client.hpp\"" + NL - result += "#include \"test/routines/level" + level_string + "/x" + routine.name + ".hpp\"" + NL + NL + result += "#include \"test/routines/level" + level_string + "/x" + routine.lowercase_name() + ".hpp\"" + NL + NL result += "// Shortcuts to the clblast namespace" + NL result += "using float2 = clblast::float2;" + NL result += "using double2 = clblast::double2;" + NL + NL @@ -304,7 +308,7 @@ def performance_test(routine, level_string): found = False for flavour in routine.flavours: if flavour.precision_name == precision: - result += NL + " clblast::RunClient<clblast::TestX" + routine.name + flavour.test_template() + result += NL + " clblast::RunClient<clblast::TestX" + routine.plain_name() + flavour.test_template() result += ">(argc, argv); break;" + NL found = True if not found: @@ -319,7 +323,7 @@ def correctness_test(routine, level_string): """Generates the body of a correctness test for a specific routine""" result = "" result += "#include \"test/correctness/testblas.hpp\"" + NL - result += "#include \"test/routines/level" + level_string + "/x" + routine.name + ".hpp\"" + NL + NL + result += "#include \"test/routines/level" + level_string + "/x" + routine.lowercase_name() + ".hpp\"" + NL + NL result += "// Shortcuts to the clblast namespace" + NL result += "using float2 = clblast::float2;" + NL result += "using double2 = clblast::double2;" + NL + NL @@ -328,8 +332,8 @@ def correctness_test(routine, level_string): result += " auto errors = size_t{0};" + NL not_first = "false" for flavour in routine.flavours: - result += " errors += clblast::RunTests<clblast::TestX" + routine.name + flavour.test_template() - result += ">(argc, argv, " + not_first + ", \"" + flavour.name + routine.name.upper() + "\");" + NL + result += " errors += clblast::RunTests<clblast::TestX" + routine.plain_name() + flavour.test_template() + result += ">(argc, argv, " + not_first + ", \"" + flavour.name + routine.upper_name() + "\");" + NL not_first = "true" result += " if (errors > 0) { return 1; } else { return 0; }" + NL result += "}" + NL |