summaryrefslogtreecommitdiff
path: root/scripts/generator
diff options
context:
space:
mode:
authorCNugteren <web@cedricnugteren.nl>2015-09-18 10:19:03 +0200
committerCNugteren <web@cedricnugteren.nl>2015-09-18 10:19:03 +0200
commit4796c9bcbd84a9e8be1e2864ba47e0d6bf3e6632 (patch)
treed74289f024341071ec11324a9827e2e33428f895 /scripts/generator
parent6105ad6f5b40b319477be7b51b8631e510d58672 (diff)
Added generated main functions for correctness/performance tests for level 2 routines
Diffstat (limited to 'scripts/generator')
-rw-r--r--scripts/generator/datatype.py9
-rw-r--r--scripts/generator/generator.py99
2 files changed, 99 insertions, 9 deletions
diff --git a/scripts/generator/datatype.py b/scripts/generator/datatype.py
index cca3534d..0aa27197 100644
--- a/scripts/generator/datatype.py
+++ b/scripts/generator/datatype.py
@@ -29,7 +29,7 @@ class DataType():
self.beta_cpp = scalars[1]
self.alpha_cl = scalars[2]
self.beta_cl = scalars[3]
- self.buffertype = buffertype # Only used for template types
+ self.buffertype = buffertype
# Outputs the name of the data-type (alpha/beta), possibly transforming into the right type
def UseAlpha(self):
@@ -51,4 +51,11 @@ class DataType():
return self.beta_cl+"{{beta.real(), beta.imag()}}"
return "beta"
+ # Returns the template as used in the correctness/performance tests
+ def TestTemplate(self):
+ if self.buffertype != self.beta_cpp:
+ return "<"+self.buffertype+","+self.beta_cpp+">, "+self.buffertype+", "+self.beta_cpp
+ return "<"+self.buffertype+">, "+self.buffertype+", "+self.beta_cpp
+
+
# ==================================================================================================
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 9c9675b8..677c8afc 100644
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -14,6 +14,9 @@
# clblast_c.h
# clblast_c.cc
# wrapper_clblas.h
+# It also generates the main functions for the correctness and performance tests as found in
+# test/correctness/routines/levelX/xYYYY.cc
+# test/performance/routines/levelX/xYYYY.cc
#
# ==================================================================================================
@@ -30,12 +33,12 @@ from datatype import DataType, FLT, DBL, FLT2, DBL2, F2CL, D2CL
# Regular data-types
S = DataType("S", FLT, [FLT, FLT, FLT, FLT], FLT ) # single (32)
D = DataType("D", DBL, [DBL, DBL, DBL, DBL], DBL ) # double (64)
-C = DataType("C", FLT2, [FLT2, FLT2, F2CL, F2CL], F2CL) # single-complex (3232)
-Z = DataType("Z", DBL2, [DBL2, DBL2, D2CL, D2CL], D2CL) # double-complex (6464)
+C = DataType("C", FLT2, [FLT2, FLT2, F2CL, F2CL], FLT2) # single-complex (3232)
+Z = DataType("Z", DBL2, [DBL2, DBL2, D2CL, D2CL], DBL2) # double-complex (6464)
# Special cases
-Css = DataType("C", FLT, [FLT, FLT, FLT, FLT], FLT ) # As C, but with constants from S
-Zdd = DataType("Z", DBL, [DBL, DBL, DBL, DBL], DBL ) # As Z, but with constants from D
+Css = DataType("C", FLT, [FLT, FLT, FLT, FLT], FLT2) # As C, but with constants from S
+Zdd = DataType("Z", DBL, [DBL, DBL, DBL, DBL], DBL2) # As Z, but with constants from D
Ccs = DataType("C", FLT2+","+FLT, [FLT2, FLT, F2CL, FLT], FLT2) # As C, but with one constant from S
Zzd = DataType("Z", DBL2+","+DBL, [DBL2, DBL, D2CL, DBL], DBL2) # As Z, but with one constant from D
@@ -115,6 +118,22 @@ separators = ["""
// BLAS level-3 (matrix-matrix) routines
// ================================================================================================="""]
+# Main header/footer for source files
+header = """
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// =================================================================================================
+"""
+footer = """
+// =================================================================================================
+"""
+
# ==================================================================================================
# The C++ API header (.h)
@@ -235,8 +254,8 @@ for i in xrange(0,len(files)):
# Stores the header and the footer of the original file
with open(files[i]) as f:
original = f.readlines()
- header = original[:header_lines[i]]
- footer = original[-footer_lines[i]:]
+ file_header = original[:header_lines[i]]
+ file_footer = original[-footer_lines[i]:]
# Re-writes the body of the file
with open(files[i], "w") as f:
@@ -253,8 +272,72 @@ for i in xrange(0,len(files)):
body += clblast_c_cc(routines[level-1])
if i == 4:
body += wrapper_clblas(routines[level-1])
- f.write("".join(header))
+ f.write("".join(file_header))
f.write(body)
- f.write("".join(footer))
+ f.write("".join(file_footer))
+
+# ==================================================================================================
+
+# Outputs all the correctness-test implementations
+for level in [1,2,3]:
+ for routine in routines[level-1]:
+ filename = path_clblast+"/test/correctness/routines/level"+str(level)+"/x"+routine.name+".cc"
+ with open(filename, "w") as f:
+ body = ""
+ body += "#include \"correctness/testblas.h\"\n"
+ body += "#include \"routines/level"+str(level)+"/x"+routine.name+".h\"\n\n"
+ body += "// Shortcuts to the clblast namespace\n"
+ body += "using float2 = clblast::float2;\n"
+ body += "using double2 = clblast::double2;\n\n"
+ body += "// Main function (not within the clblast namespace)\n"
+ body += "int main(int argc, char *argv[]) {\n"
+ not_first = "false"
+ for flavour in routine.flavours:
+ body += " clblast::RunTests<clblast::TestX"+routine.name+flavour.TestTemplate()
+ body += ">(argc, argv, "+not_first+", \""+flavour.name+routine.name.upper()+"\");\n"
+ not_first = "true"
+ body += " return 0;\n"
+ body += "}\n"
+ f.write(header+"\n")
+ f.write(body)
+ f.write(footer)
+
+# Outputs all the performance-test implementations
+for level in [1,2,3]:
+ for routine in routines[level-1]:
+ filename = path_clblast+"/test/performance/routines/level"+str(level)+"/x"+routine.name+".cc"
+ with open(filename, "w") as f:
+ body = ""
+ body += "#include \"performance/client.h\"\n"
+ body += "#include \"routines/level"+str(level)+"/x"+routine.name+".h\"\n\n"
+ body += "// Shortcuts to the clblast namespace\n"
+ body += "using float2 = clblast::float2;\n"
+ body += "using double2 = clblast::double2;\n\n"
+ body += "// Main function (not within the clblast namespace)\n"
+ body += "int main(int argc, char *argv[]) {\n"
+ body += " switch(clblast::GetPrecision(argc, argv)) {\n"
+ for precision in ["H","S","D","C","Z"]:
+ enum = {
+ 'H': "Half",
+ 'S': "Single",
+ 'D': "Double",
+ 'C': "ComplexSingle",
+ 'Z': "ComplexDouble",
+ }[precision]
+ body += " case clblast::Precision::k"+enum+":"
+ found = False
+ for flavour in routine.flavours:
+ if flavour.name == precision:
+ body += "\n clblast::RunClient<clblast::TestX"+routine.name+flavour.TestTemplate()
+ body += ">(argc, argv); break;\n"
+ found = True
+ if not found:
+ body += " throw std::runtime_error(\"Unsupported precision mode\");\n"
+ body += " }\n"
+ body += " return 0;\n"
+ body += "}\n"
+ f.write(header+"\n")
+ f.write(body)
+ f.write(footer)
# ==================================================================================================