diff options
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/database/database/clblast.py | 72 |
1 files changed, 45 insertions, 27 deletions
diff --git a/scripts/database/database/clblast.py b/scripts/database/database/clblast.py index 803d1d2a..779dd76c 100644 --- a/scripts/database/database/clblast.py +++ b/scripts/database/database/clblast.py @@ -42,20 +42,19 @@ def get_cpp_separator(): return "// =================================================================================================" -def get_cpp_header(family): +def get_cpp_header(family, precision): """Retrieves the C++ header""" return ("\n" + get_cpp_separator() + """ -// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This -// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- -// width of 100 characters per line. +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. It +// is auto-generated by the 'scripts/database/database.py' Python script. // -// Author(s): -// Database generator <database.py> -// -// This file populates the database with best-found tuning parameters for the '%s' kernels. +// This file populates the database with best-found tuning parameters for the '%s%s' kernels. //\n""" - % family.title() + get_cpp_separator() + \ - "\n\nnamespace clblast {\n" + "namespace database {\n" + get_cpp_separator()) + % (family.title(), precision)) + get_cpp_separator() + "\n" + + +def get_cpp_header_namespace(): + return "\nnamespace clblast {\n" + "namespace database {\n" def get_cpp_footer(): @@ -67,7 +66,7 @@ def get_cpp_precision(family, precision): """Retrieves the C++ code for the start of a new precision""" precision_string = precision_to_string(precision) camelcase_name = family.title().replace("_", "") - return("\n\nconst Database::DatabaseEntry %s%s = {\n \"%s\", Precision::k%s" + return("\nconst DatabaseEntry %s%s = {\n \"%s\", Precision::k%s" % (camelcase_name, precision_string, camelcase_name, precision_string)) @@ -79,6 +78,15 @@ def get_cpp_device_vendor(vendor, device_type): return " { // %s %ss\n kDeviceType%s, \"%s\", {\n" % (vendor, device_type, device_type_caps, vendor) +def get_cpp_family_includes(family, precisions): + result = "\n" + # result += "#include \"clblast.h\"\n" + # result += "#include \"database/database_structure.hpp\"\n" + for precision in precisions: + result += "#include \"database/kernels/%s/%s_%s.hpp\"\n" % (family, family, precision) + return result + + def print_cpp_database(database, output_dir): """Outputs the database as C++ code""" @@ -87,19 +95,23 @@ def print_cpp_database(database, output_dir): for family_name in kernel_families: family_database = [s for s in database["sections"] if s["kernel_family"] == family_name] - # Opens a new file for each kernel family - full_path = os.path.join(output_dir, family_name + ".hpp") - with open(full_path, 'w+') as f: - f.write(get_cpp_header(family_name)) + # Goes into a new path for each kernel family + family_path = os.path.join(output_dir, family_name) + + # Loops over the different precision (e.g. 16, 32, 3232, 64, 6464) + precisions = sorted(set([s["precision"] for s in database["sections"]])) # Based on full database + for precision in precisions: + precision_database = [s for s in family_database if s["precision"] == precision] - # Loops over the different precision (e.g. 16, 32, 3232, 64, 6464) - precisions = sorted(set([s["precision"] for s in database["sections"]])) # Based on full database - for precision in precisions: - precision_database = [s for s in family_database if s["precision"] == precision] + # Opens a new file for each precision + full_path = os.path.join(family_path, family_name + "_" + precision + ".hpp") + with open(full_path, 'w+') as f: + f.write(get_cpp_header(family_name, precision)) + f.write(get_cpp_header_namespace()) f.write(get_cpp_precision(family_name, precision)) - # In case there is nothing found at all (e.g. 16-bit): continue as if this was a precision of 32 but - # with the defaults only + # In case there is nothing found at all (e.g. 16-bit): continue as if this was a + # precision of 32 but with the defaults only if len(precision_database) == 0: print("[database] No results found for %s:%s, retrieving defaults from %s:32" % (family_name, precision, family_name)) @@ -138,7 +150,7 @@ def print_cpp_database(database, output_dir): # Collects the parameters for this entry parameters = [] - parmameter_index = 0 + parameter_index = 0 kernels = sorted(set([s["kernel"] for s in device_database])) for kernel in kernels: kernel_database = [s for s in device_database if s["kernel"] == kernel] @@ -149,10 +161,10 @@ def print_cpp_database(database, output_dir): assert len(results) == 1 new_parameters = results[0]["parameters"] for parameter_name in sorted(new_parameters): - assert parameter_name == parameter_names[parmameter_index] + assert parameter_name == parameter_names[parameter_index] parameter_value = new_parameters[parameter_name] parameters.append(str(parameter_value)) - parmameter_index += 1 + parameter_index += 1 # Prints the entry f.write(", ".join(parameters)) @@ -162,7 +174,13 @@ def print_cpp_database(database, output_dir): f.write(" }\n },\n") # Prints the precision footer - f.write(" }\n};\n\n" + get_cpp_separator()) + f.write(" }\n};\n") + + # Prints the file footer + f.write(get_cpp_footer()) - # Prints the file footer - f.write(get_cpp_footer()) + # Creates the combined family includes header + full_path = os.path.join(family_path, family_name + ".hpp") + with open(full_path, 'w+') as f: + f.write(get_cpp_header(family_name, "")) + f.write(get_cpp_family_includes(family_name, precisions)) |