diff options
author | CNugteren <web@cedricnugteren.nl> | 2015-09-18 10:19:03 +0200 |
---|---|---|
committer | CNugteren <web@cedricnugteren.nl> | 2015-09-18 10:19:03 +0200 |
commit | 4796c9bcbd84a9e8be1e2864ba47e0d6bf3e6632 (patch) | |
tree | d74289f024341071ec11324a9827e2e33428f895 /scripts/generator | |
parent | 6105ad6f5b40b319477be7b51b8631e510d58672 (diff) |
Added generated main functions for correctness/performance tests for level 2 routines
Diffstat (limited to 'scripts/generator')
-rw-r--r-- | scripts/generator/datatype.py | 9 | ||||
-rw-r--r-- | scripts/generator/generator.py | 99 |
2 files changed, 99 insertions, 9 deletions
diff --git a/scripts/generator/datatype.py b/scripts/generator/datatype.py index cca3534d..0aa27197 100644 --- a/scripts/generator/datatype.py +++ b/scripts/generator/datatype.py @@ -29,7 +29,7 @@ class DataType(): self.beta_cpp = scalars[1] self.alpha_cl = scalars[2] self.beta_cl = scalars[3] - self.buffertype = buffertype # Only used for template types + self.buffertype = buffertype # Outputs the name of the data-type (alpha/beta), possibly transforming into the right type def UseAlpha(self): @@ -51,4 +51,11 @@ class DataType(): return self.beta_cl+"{{beta.real(), beta.imag()}}" return "beta" + # Returns the template as used in the correctness/performance tests + def TestTemplate(self): + if self.buffertype != self.beta_cpp: + return "<"+self.buffertype+","+self.beta_cpp+">, "+self.buffertype+", "+self.beta_cpp + return "<"+self.buffertype+">, "+self.buffertype+", "+self.beta_cpp + + # ================================================================================================== diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index 9c9675b8..677c8afc 100644 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -14,6 +14,9 @@ # clblast_c.h # clblast_c.cc # wrapper_clblas.h +# It also generates the main functions for the correctness and performance tests as found in +# test/correctness/routines/levelX/xYYYY.cc +# test/performance/routines/levelX/xYYYY.cc # # ================================================================================================== @@ -30,12 +33,12 @@ from datatype import DataType, FLT, DBL, FLT2, DBL2, F2CL, D2CL # Regular data-types S = DataType("S", FLT, [FLT, FLT, FLT, FLT], FLT ) # single (32) D = DataType("D", DBL, [DBL, DBL, DBL, DBL], DBL ) # double (64) -C = DataType("C", FLT2, [FLT2, FLT2, F2CL, F2CL], F2CL) # single-complex (3232) -Z = DataType("Z", DBL2, [DBL2, DBL2, D2CL, D2CL], D2CL) # double-complex (6464) +C = DataType("C", FLT2, [FLT2, FLT2, F2CL, F2CL], FLT2) # single-complex (3232) +Z = DataType("Z", DBL2, [DBL2, DBL2, D2CL, D2CL], DBL2) # double-complex (6464) # Special cases -Css = DataType("C", FLT, [FLT, FLT, FLT, FLT], FLT ) # As C, but with constants from S -Zdd = DataType("Z", DBL, [DBL, DBL, DBL, DBL], DBL ) # As Z, but with constants from D +Css = DataType("C", FLT, [FLT, FLT, FLT, FLT], FLT2) # As C, but with constants from S +Zdd = DataType("Z", DBL, [DBL, DBL, DBL, DBL], DBL2) # As Z, but with constants from D Ccs = DataType("C", FLT2+","+FLT, [FLT2, FLT, F2CL, FLT], FLT2) # As C, but with one constant from S Zzd = DataType("Z", DBL2+","+DBL, [DBL2, DBL, D2CL, DBL], DBL2) # As Z, but with one constant from D @@ -115,6 +118,22 @@ separators = [""" // BLAS level-3 (matrix-matrix) routines // ================================================================================================="""] +# Main header/footer for source files +header = """ +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren <www.cedricnugteren.nl> +// +// ================================================================================================= +""" +footer = """ +// ================================================================================================= +""" + # ================================================================================================== # The C++ API header (.h) @@ -235,8 +254,8 @@ for i in xrange(0,len(files)): # Stores the header and the footer of the original file with open(files[i]) as f: original = f.readlines() - header = original[:header_lines[i]] - footer = original[-footer_lines[i]:] + file_header = original[:header_lines[i]] + file_footer = original[-footer_lines[i]:] # Re-writes the body of the file with open(files[i], "w") as f: @@ -253,8 +272,72 @@ for i in xrange(0,len(files)): body += clblast_c_cc(routines[level-1]) if i == 4: body += wrapper_clblas(routines[level-1]) - f.write("".join(header)) + f.write("".join(file_header)) f.write(body) - f.write("".join(footer)) + f.write("".join(file_footer)) + +# ================================================================================================== + +# Outputs all the correctness-test implementations +for level in [1,2,3]: + for routine in routines[level-1]: + filename = path_clblast+"/test/correctness/routines/level"+str(level)+"/x"+routine.name+".cc" + with open(filename, "w") as f: + body = "" + body += "#include \"correctness/testblas.h\"\n" + body += "#include \"routines/level"+str(level)+"/x"+routine.name+".h\"\n\n" + body += "// Shortcuts to the clblast namespace\n" + body += "using float2 = clblast::float2;\n" + body += "using double2 = clblast::double2;\n\n" + body += "// Main function (not within the clblast namespace)\n" + body += "int main(int argc, char *argv[]) {\n" + not_first = "false" + for flavour in routine.flavours: + body += " clblast::RunTests<clblast::TestX"+routine.name+flavour.TestTemplate() + body += ">(argc, argv, "+not_first+", \""+flavour.name+routine.name.upper()+"\");\n" + not_first = "true" + body += " return 0;\n" + body += "}\n" + f.write(header+"\n") + f.write(body) + f.write(footer) + +# Outputs all the performance-test implementations +for level in [1,2,3]: + for routine in routines[level-1]: + filename = path_clblast+"/test/performance/routines/level"+str(level)+"/x"+routine.name+".cc" + with open(filename, "w") as f: + body = "" + body += "#include \"performance/client.h\"\n" + body += "#include \"routines/level"+str(level)+"/x"+routine.name+".h\"\n\n" + body += "// Shortcuts to the clblast namespace\n" + body += "using float2 = clblast::float2;\n" + body += "using double2 = clblast::double2;\n\n" + body += "// Main function (not within the clblast namespace)\n" + body += "int main(int argc, char *argv[]) {\n" + body += " switch(clblast::GetPrecision(argc, argv)) {\n" + for precision in ["H","S","D","C","Z"]: + enum = { + 'H': "Half", + 'S': "Single", + 'D': "Double", + 'C': "ComplexSingle", + 'Z': "ComplexDouble", + }[precision] + body += " case clblast::Precision::k"+enum+":" + found = False + for flavour in routine.flavours: + if flavour.name == precision: + body += "\n clblast::RunClient<clblast::TestX"+routine.name+flavour.TestTemplate() + body += ">(argc, argv); break;\n" + found = True + if not found: + body += " throw std::runtime_error(\"Unsupported precision mode\");\n" + body += " }\n" + body += " return 0;\n" + body += "}\n" + f.write(header+"\n") + f.write(body) + f.write(footer) # ================================================================================================== |