summaryrefslogtreecommitdiff
path: root/scripts/database
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-09-06 21:50:42 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2017-09-06 21:50:42 +0200
commit20da5e33a86eda746c17cbdb7bfd295d9f92f074 (patch)
treed35e7091ddc8bbd81d581c4bd49468c6329111fd /scripts/database
parentbb947890dec90712c92028c20234eafd48e6fa3e (diff)
Split the database files over multiple directories and files; first step towards separate compilation
Diffstat (limited to 'scripts/database')
-rw-r--r--scripts/database/database/clblast.py72
1 files changed, 45 insertions, 27 deletions
diff --git a/scripts/database/database/clblast.py b/scripts/database/database/clblast.py
index 803d1d2a..779dd76c 100644
--- a/scripts/database/database/clblast.py
+++ b/scripts/database/database/clblast.py
@@ -42,20 +42,19 @@ def get_cpp_separator():
return "// ================================================================================================="
-def get_cpp_header(family):
+def get_cpp_header(family, precision):
"""Retrieves the C++ header"""
return ("\n" + get_cpp_separator() + """
-// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
-// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
-// width of 100 characters per line.
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. It
+// is auto-generated by the 'scripts/database/database.py' Python script.
//
-// Author(s):
-// Database generator <database.py>
-//
-// This file populates the database with best-found tuning parameters for the '%s' kernels.
+// This file populates the database with best-found tuning parameters for the '%s%s' kernels.
//\n"""
- % family.title() + get_cpp_separator() + \
- "\n\nnamespace clblast {\n" + "namespace database {\n" + get_cpp_separator())
+ % (family.title(), precision)) + get_cpp_separator() + "\n"
+
+
+def get_cpp_header_namespace():
+ return "\nnamespace clblast {\n" + "namespace database {\n"
def get_cpp_footer():
@@ -67,7 +66,7 @@ def get_cpp_precision(family, precision):
"""Retrieves the C++ code for the start of a new precision"""
precision_string = precision_to_string(precision)
camelcase_name = family.title().replace("_", "")
- return("\n\nconst Database::DatabaseEntry %s%s = {\n \"%s\", Precision::k%s"
+ return("\nconst DatabaseEntry %s%s = {\n \"%s\", Precision::k%s"
% (camelcase_name, precision_string, camelcase_name, precision_string))
@@ -79,6 +78,15 @@ def get_cpp_device_vendor(vendor, device_type):
return " { // %s %ss\n kDeviceType%s, \"%s\", {\n" % (vendor, device_type, device_type_caps, vendor)
+def get_cpp_family_includes(family, precisions):
+ result = "\n"
+ # result += "#include \"clblast.h\"\n"
+ # result += "#include \"database/database_structure.hpp\"\n"
+ for precision in precisions:
+ result += "#include \"database/kernels/%s/%s_%s.hpp\"\n" % (family, family, precision)
+ return result
+
+
def print_cpp_database(database, output_dir):
"""Outputs the database as C++ code"""
@@ -87,19 +95,23 @@ def print_cpp_database(database, output_dir):
for family_name in kernel_families:
family_database = [s for s in database["sections"] if s["kernel_family"] == family_name]
- # Opens a new file for each kernel family
- full_path = os.path.join(output_dir, family_name + ".hpp")
- with open(full_path, 'w+') as f:
- f.write(get_cpp_header(family_name))
+ # Goes into a new path for each kernel family
+ family_path = os.path.join(output_dir, family_name)
+
+ # Loops over the different precision (e.g. 16, 32, 3232, 64, 6464)
+ precisions = sorted(set([s["precision"] for s in database["sections"]])) # Based on full database
+ for precision in precisions:
+ precision_database = [s for s in family_database if s["precision"] == precision]
- # Loops over the different precision (e.g. 16, 32, 3232, 64, 6464)
- precisions = sorted(set([s["precision"] for s in database["sections"]])) # Based on full database
- for precision in precisions:
- precision_database = [s for s in family_database if s["precision"] == precision]
+ # Opens a new file for each precision
+ full_path = os.path.join(family_path, family_name + "_" + precision + ".hpp")
+ with open(full_path, 'w+') as f:
+ f.write(get_cpp_header(family_name, precision))
+ f.write(get_cpp_header_namespace())
f.write(get_cpp_precision(family_name, precision))
- # In case there is nothing found at all (e.g. 16-bit): continue as if this was a precision of 32 but
- # with the defaults only
+ # In case there is nothing found at all (e.g. 16-bit): continue as if this was a
+ # precision of 32 but with the defaults only
if len(precision_database) == 0:
print("[database] No results found for %s:%s, retrieving defaults from %s:32" %
(family_name, precision, family_name))
@@ -138,7 +150,7 @@ def print_cpp_database(database, output_dir):
# Collects the parameters for this entry
parameters = []
- parmameter_index = 0
+ parameter_index = 0
kernels = sorted(set([s["kernel"] for s in device_database]))
for kernel in kernels:
kernel_database = [s for s in device_database if s["kernel"] == kernel]
@@ -149,10 +161,10 @@ def print_cpp_database(database, output_dir):
assert len(results) == 1
new_parameters = results[0]["parameters"]
for parameter_name in sorted(new_parameters):
- assert parameter_name == parameter_names[parmameter_index]
+ assert parameter_name == parameter_names[parameter_index]
parameter_value = new_parameters[parameter_name]
parameters.append(str(parameter_value))
- parmameter_index += 1
+ parameter_index += 1
# Prints the entry
f.write(", ".join(parameters))
@@ -162,7 +174,13 @@ def print_cpp_database(database, output_dir):
f.write(" }\n },\n")
# Prints the precision footer
- f.write(" }\n};\n\n" + get_cpp_separator())
+ f.write(" }\n};\n")
+
+ # Prints the file footer
+ f.write(get_cpp_footer())
- # Prints the file footer
- f.write(get_cpp_footer())
+ # Creates the combined family includes header
+ full_path = os.path.join(family_path, family_name + ".hpp")
+ with open(full_path, 'w+') as f:
+ f.write(get_cpp_header(family_name, ""))
+ f.write(get_cpp_family_includes(family_name, precisions))