summaryrefslogtreecommitdiff
path: root/scripts/generator/generator/cpp.py
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-03-05 10:38:38 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2017-03-05 10:38:38 +0100
commitf9a520b3aff7b4eec99d9e11a03f9467e7ab351c (patch)
tree90612ac786448fa6e76681ecf6755f1c35c458a4 /scripts/generator/generator/cpp.py
parent37228c90988509acef9e8a892a752300b7645210 (diff)
Prepared generator for batched routines; added batched AXPY routine interface
Diffstat (limited to 'scripts/generator/generator/cpp.py')
-rw-r--r--scripts/generator/generator/cpp.py22
1 files changed, 13 insertions, 9 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py
index c14d00a1..91fdf458 100644
--- a/scripts/generator/generator/cpp.py
+++ b/scripts/generator/generator/cpp.py
@@ -51,8 +51,10 @@ def clblast_cc(routine):
result += routine.routine_header_cpp(12, "") + " {" + NL
result += " try {" + NL
result += " auto queue_cpp = Queue(*queue);" + NL
- result += " auto routine = X" + routine.name + "<" + routine.template.template + ">(queue_cpp, event);" + NL
- result += " routine.Do" + routine.name.capitalize() + "("
+ result += " auto routine = X" + routine.plain_name() + "<" + routine.template.template + ">(queue_cpp, event);" + NL
+ if routine.batched:
+ result += " " + (NL + " ").join(routine.batched_transform_to_cpp()) + NL
+ result += " routine.Do" + routine.capitalized_name() + "("
result += ("," + NL + indent1).join([a for a in routine.arguments_clcudaapi()])
result += ");" + NL
result += " return StatusCode::kSuccess;" + NL
@@ -63,7 +65,7 @@ def clblast_cc(routine):
result += "}" + NL
for flavour in routine.flavours:
indent2 = " " * (34 + routine.length() + len(flavour.template))
- result += "template StatusCode PUBLIC_API " + routine.name.capitalize() + "<" + flavour.template + ">("
+ result += "template StatusCode PUBLIC_API " + routine.capitalized_name() + "<" + flavour.template + ">("
result += ("," + NL + indent2).join([a for a in routine.arguments_type(flavour)])
result += "," + NL + indent2 + "cl_command_queue*, cl_event*);" + NL
return result
@@ -84,9 +86,11 @@ def clblast_c_cc(routine):
template = "<" + flavour.template + ">" if routine.no_scalars() else ""
indent = " " * (16 + routine.length() + len(template))
result += routine.routine_header_c(flavour, 27, "") + " {" + NL
+ if routine.batched:
+ result += " " + (NL + " ").join(routine.batched_transform_to_complex(flavour)) + NL
result += " try {" + NL
result += " return static_cast<CLBlastStatusCode>(" + NL
- result += " clblast::" + routine.name.capitalize() + template + "("
+ result += " clblast::" + routine.capitalized_name() + template + "("
result += ("," + NL + indent).join([a for a in routine.arguments_cast(flavour, indent)])
result += "," + NL + indent + "queue, event)" + NL
result += " );" + NL
@@ -290,7 +294,7 @@ def performance_test(routine, level_string):
"""Generates the body of a performance test for a specific routine"""
result = ""
result += "#include \"test/performance/client.hpp\"" + NL
- result += "#include \"test/routines/level" + level_string + "/x" + routine.name + ".hpp\"" + NL + NL
+ result += "#include \"test/routines/level" + level_string + "/x" + routine.lowercase_name() + ".hpp\"" + NL + NL
result += "// Shortcuts to the clblast namespace" + NL
result += "using float2 = clblast::float2;" + NL
result += "using double2 = clblast::double2;" + NL + NL
@@ -304,7 +308,7 @@ def performance_test(routine, level_string):
found = False
for flavour in routine.flavours:
if flavour.precision_name == precision:
- result += NL + " clblast::RunClient<clblast::TestX" + routine.name + flavour.test_template()
+ result += NL + " clblast::RunClient<clblast::TestX" + routine.plain_name() + flavour.test_template()
result += ">(argc, argv); break;" + NL
found = True
if not found:
@@ -319,7 +323,7 @@ def correctness_test(routine, level_string):
"""Generates the body of a correctness test for a specific routine"""
result = ""
result += "#include \"test/correctness/testblas.hpp\"" + NL
- result += "#include \"test/routines/level" + level_string + "/x" + routine.name + ".hpp\"" + NL + NL
+ result += "#include \"test/routines/level" + level_string + "/x" + routine.lowercase_name() + ".hpp\"" + NL + NL
result += "// Shortcuts to the clblast namespace" + NL
result += "using float2 = clblast::float2;" + NL
result += "using double2 = clblast::double2;" + NL + NL
@@ -328,8 +332,8 @@ def correctness_test(routine, level_string):
result += " auto errors = size_t{0};" + NL
not_first = "false"
for flavour in routine.flavours:
- result += " errors += clblast::RunTests<clblast::TestX" + routine.name + flavour.test_template()
- result += ">(argc, argv, " + not_first + ", \"" + flavour.name + routine.name.upper() + "\");" + NL
+ result += " errors += clblast::RunTests<clblast::TestX" + routine.plain_name() + flavour.test_template()
+ result += ">(argc, argv, " + not_first + ", \"" + flavour.name + routine.upper_name() + "\");" + NL
not_first = "true"
result += " if (errors > 0) { return 1; } else { return 0; }" + NL
result += "}" + NL