summary | refs | log | tree | commit | diff
path: root/scripts/generator/generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/generator/generator.py')
-rw-r--r-- scripts/generator/generator.py 99
1 files changed, 91 insertions, 8 deletions
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 9c9675b8..677c8afc 100644
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -14,6 +14,9 @@
# clblast_c.h
# clblast_c.cc
# wrapper_clblas.h
+# It also generates the main functions for the correctness and performance tests as found in
+# test/correctness/routines/levelX/xYYYY.cc
+# test/performance/routines/levelX/xYYYY.cc
#
# ==================================================================================================
@@ -30,12 +33,12 @@ from datatype import DataType, FLT, DBL, FLT2, DBL2, F2CL, D2CL
# Regular data-types
S = DataType("S", FLT, [FLT, FLT, FLT, FLT], FLT ) # single (32)
D = DataType("D", DBL, [DBL, DBL, DBL, DBL], DBL ) # double (64)
-C = DataType("C", FLT2, [FLT2, FLT2, F2CL, F2CL], F2CL) # single-complex (3232)
-Z = DataType("Z", DBL2, [DBL2, DBL2, D2CL, D2CL], D2CL) # double-complex (6464)
+C = DataType("C", FLT2, [FLT2, FLT2, F2CL, F2CL], FLT2) # single-complex (3232)
+Z = DataType("Z", DBL2, [DBL2, DBL2, D2CL, D2CL], DBL2) # double-complex (6464)
# Special cases
-Css = DataType("C", FLT, [FLT, FLT, FLT, FLT], FLT ) # As C, but with constants from S
-Zdd = DataType("Z", DBL, [DBL, DBL, DBL, DBL], DBL ) # As Z, but with constants from D
+Css = DataType("C", FLT, [FLT, FLT, FLT, FLT], FLT2) # As C, but with constants from S
+Zdd = DataType("Z", DBL, [DBL, DBL, DBL, DBL], DBL2) # As Z, but with constants from D
Ccs = DataType("C", FLT2+","+FLT, [FLT2, FLT, F2CL, FLT], FLT2) # As C, but with one constant from S
Zzd = DataType("Z", DBL2+","+DBL, [DBL2, DBL, D2CL, DBL], DBL2) # As Z, but with one constant from D
@@ -115,6 +118,22 @@ separators = ["""
// BLAS level-3 (matrix-matrix) routines
// ================================================================================================="""]
+# Main header/footer for source files
+header = """
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// =================================================================================================
+"""
+footer = """
+// =================================================================================================
+"""
+
# ==================================================================================================
# The C++ API header (.h)
@@ -235,8 +254,8 @@ for i in xrange(0,len(files)):
# Stores the header and the footer of the original file
with open(files[i]) as f:
original = f.readlines()
- header = original[:header_lines[i]]
- footer = original[-footer_lines[i]:]
+ file_header = original[:header_lines[i]]
+ file_footer = original[-footer_lines[i]:]
# Re-writes the body of the file
with open(files[i], "w") as f:
@@ -253,8 +272,72 @@ for i in xrange(0,len(files)):
body += clblast_c_cc(routines[level-1])
if i == 4:
body += wrapper_clblas(routines[level-1])
- f.write("".join(header))
+ f.write("".join(file_header))
f.write(body)
- f.write("".join(footer))
+ f.write("".join(file_footer))
+
+# ==================================================================================================
+
+# Outputs all the correctness-test implementations
+for level in [1,2,3]:
+ for routine in routines[level-1]:
+ filename = path_clblast+"/test/correctness/routines/level"+str(level)+"/x"+routine.name+".cc"
+ with open(filename, "w") as f:
+ body = ""
+ body += "#include \"correctness/testblas.h\"\n"
+ body += "#include \"routines/level"+str(level)+"/x"+routine.name+".h\"\n\n"
+ body += "// Shortcuts to the clblast namespace\n"
+ body += "using float2 = clblast::float2;\n"
+ body += "using double2 = clblast::double2;\n\n"
+ body += "// Main function (not within the clblast namespace)\n"
+ body += "int main(int argc, char *argv[]) {\n"
+ not_first = "false"
+ for flavour in routine.flavours:
+ body += " clblast::RunTests<clblast::TestX"+routine.name+flavour.TestTemplate()
+ body += ">(argc, argv, "+not_first+", \""+flavour.name+routine.name.upper()+"\");\n"
+ not_first = "true"
+ body += " return 0;\n"
+ body += "}\n"
+ f.write(header+"\n")
+ f.write(body)
+ f.write(footer)
+
+# Outputs all the performance-test implementations
+for level in [1,2,3]:
+ for routine in routines[level-1]:
+ filename = path_clblast+"/test/performance/routines/level"+str(level)+"/x"+routine.name+".cc"
+ with open(filename, "w") as f:
+ body = ""
+ body += "#include \"performance/client.h\"\n"
+ body += "#include \"routines/level"+str(level)+"/x"+routine.name+".h\"\n\n"
+ body += "// Shortcuts to the clblast namespace\n"
+ body += "using float2 = clblast::float2;\n"
+ body += "using double2 = clblast::double2;\n\n"
+ body += "// Main function (not within the clblast namespace)\n"
+ body += "int main(int argc, char *argv[]) {\n"
+ body += " switch(clblast::GetPrecision(argc, argv)) {\n"
+ for precision in ["H","S","D","C","Z"]:
+ enum = {
+ 'H': "Half",
+ 'S': "Single",
+ 'D': "Double",
+ 'C': "ComplexSingle",
+ 'Z': "ComplexDouble",
+ }[precision]
+ body += " case clblast::Precision::k"+enum+":"
+ found = False
+ for flavour in routine.flavours:
+ if flavour.name == precision:
+ body += "\n clblast::RunClient<clblast::TestX"+routine.name+flavour.TestTemplate()
+ body += ">(argc, argv); break;\n"
+ found = True
+ if not found:
+ body += " throw std::runtime_error(\"Unsupported precision mode\");\n"
+ body += " }\n"
+ body += " return 0;\n"
+ body += "}\n"
+ f.write(header+"\n")
+ f.write(body)
+ f.write(footer)
# ==================================================================================================