diff options
Diffstat (limited to 'scripts/generator/generator.py')
-rw-r--r-- | scripts/generator/generator.py | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index 93ff8680..382d728a 100644 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -103,7 +103,17 @@ routines = [ ]] # ================================================================================================== +# Translates an option name to a CLBlast data-type +def PrecisionToFullName(x): + return { + 'H': "Half", + 'S': "Single", + 'D': "Double", + 'C': "ComplexSingle", + 'Z': "ComplexDouble", + }[x] +# ================================================================================================== # Separators for the BLAS levels separators = [""" // ================================================================================================= @@ -237,7 +247,7 @@ files = [ path_clblast+"/src/clblast_c.cc", path_clblast+"/test/wrapper_clblas.h", ] -header_lines = [84, 52, 80, 24, 22] +header_lines = [84, 55, 80, 24, 22] footer_lines = [6, 3, 5, 2, 6] # Checks whether the command-line arguments are valid; exists otherwise @@ -315,16 +325,10 @@ for level in [1,2,3]: body += "using double2 = clblast::double2;\n\n" body += "// Main function (not within the clblast namespace)\n" body += "int main(int argc, char *argv[]) {\n" - body += " switch(clblast::GetPrecision(argc, argv)) {\n" + default = PrecisionToFullName(routine.flavours[0].name) + body += " switch(clblast::GetPrecision(argc, argv, clblast::Precision::k"+default+")) {\n" for precision in ["H","S","D","C","Z"]: - enum = { - 'H': "Half", - 'S': "Single", - 'D': "Double", - 'C': "ComplexSingle", - 'Z': "ComplexDouble", - }[precision] - body += " case clblast::Precision::k"+enum+":" + body += " case clblast::Precision::k"+PrecisionToFullName(precision)+":" found = False for flavour in routine.flavours: if flavour.name == precision: |