summaryrefslogtreecommitdiff
path: root/scripts/generator/generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/generator/generator.py')
-rw-r--r--scripts/generator/generator.py24
1 files changed, 14 insertions, 10 deletions
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 93ff8680..382d728a 100644
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -103,7 +103,17 @@ routines = [
]]
# ==================================================================================================
+# Translates an option name to a CLBlast data-type
+def PrecisionToFullName(x):
+ return {
+ 'H': "Half",
+ 'S': "Single",
+ 'D': "Double",
+ 'C': "ComplexSingle",
+ 'Z': "ComplexDouble",
+ }[x]
+# ==================================================================================================
# Separators for the BLAS levels
separators = ["""
// =================================================================================================
@@ -237,7 +247,7 @@ files = [
path_clblast+"/src/clblast_c.cc",
path_clblast+"/test/wrapper_clblas.h",
]
-header_lines = [84, 52, 80, 24, 22]
+header_lines = [84, 55, 80, 24, 22]
footer_lines = [6, 3, 5, 2, 6]
# Checks whether the command-line arguments are valid; exists otherwise
@@ -315,16 +325,10 @@ for level in [1,2,3]:
body += "using double2 = clblast::double2;\n\n"
body += "// Main function (not within the clblast namespace)\n"
body += "int main(int argc, char *argv[]) {\n"
- body += " switch(clblast::GetPrecision(argc, argv)) {\n"
+ default = PrecisionToFullName(routine.flavours[0].name)
+ body += " switch(clblast::GetPrecision(argc, argv, clblast::Precision::k"+default+")) {\n"
for precision in ["H","S","D","C","Z"]:
- enum = {
- 'H': "Half",
- 'S': "Single",
- 'D': "Double",
- 'C': "ComplexSingle",
- 'Z': "ComplexDouble",
- }[precision]
- body += " case clblast::Precision::k"+enum+":"
+ body += " case clblast::Precision::k"+PrecisionToFullName(precision)+":"
found = False
for flavour in routine.flavours:
if flavour.name == precision: