diff options
Diffstat (limited to 'scripts/database/database.py')
-rwxr-xr-x | scripts/database/database.py | 29 |
1 files changed, 10 insertions, 19 deletions
diff --git a/scripts/database/database.py b/scripts/database/database.py index e398aa30..8f3ccce6 100755 --- a/scripts/database/database.py +++ b/scripts/database/database.py @@ -20,14 +20,6 @@ import database.defaults as defaults # Server storing a copy of the database DATABASE_SERVER_URL = "https://raw.githubusercontent.com/CNugteren/CLBlast-database/master/database.json" -# OpenCL vendor names and their short name -VENDOR_TRANSLATION_TABLE = { - "GenuineIntel": "Intel", - "Intel(R) Corporation": "Intel", - "Advanced Micro Devices, Inc.": "AMD", - "NVIDIA Corporation": "NVIDIA", -} - def remove_mismatched_arguments(database): """Checks for tuning results with mis-matched entries and removes them according to user preferences""" @@ -44,12 +36,14 @@ def remove_mismatched_arguments(database): for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes): group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES) if len(group_by_arguments) != 1: - print("[database] WARNING: entries for a single kernel with multiple argument values " + str(kernel_group_name)) - print("[database] Either quit now, or remove all but one of the argument combinations below:") + print("[database] WARNING: entries for a single kernel with multiple argument values " + + str(kernel_group_name)) + print("[database] Either quit or remove all but one of the argument combinations below:") for index, (attribute_group_name, mismatching_entries) in enumerate(group_by_arguments): print("[database] %d: %s" % (index, attribute_group_name)) for attribute_group_name, mismatching_entries in group_by_arguments: - response = user_input("[database] Remove entries corresponding to %s, [y/n]? " % str(attribute_group_name)) + response = user_input("[database] Remove entries corresponding to %s, [y/n]? " % + str(attribute_group_name)) if response == "y": for entry in mismatching_entries: database["sections"].remove(entry) @@ -59,7 +53,8 @@ def remove_mismatched_arguments(database): for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes): group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES) if len(group_by_arguments) != 1: - print("[database] ERROR: entries for a single kernel with multiple argument values " + str(kernel_group_name)) + print("[database] ERROR: entries for a single kernel with multiple argument values " + + str(kernel_group_name)) assert len(group_by_arguments) == 1 @@ -97,7 +92,8 @@ def main(argv): # Checks whether the command-line arguments are valid clblast_header = os.path.join(cl_args.clblast_root, "include", "clblast.h") # Not used but just for validation if not os.path.isfile(clblast_header): - raise RuntimeError("The path '" + cl_args.clblast_root + "' does not point to the root of the CLBlast library") + raise RuntimeError("The path '" + cl_args.clblast_root + + "' does not point to the root of the CLBlast library") if len(glob.glob(json_files)) < 1: print("[database] The path '" + cl_args.source_folder + "' does not contain any JSON files") @@ -115,11 +111,6 @@ def main(argv): sys.stdout.write("[database] Processing '" + file_json + "' ") # No newline printed imported_data = io.load_tuning_results(file_json) - # Fixes the problem that some vendors use multiple different names - for target in VENDOR_TRANSLATION_TABLE: - if imported_data["device_vendor"] == target: - imported_data["device_vendor"] = VENDOR_TRANSLATION_TABLE[target] - # Adds the new data to the database old_size = db.length(database) database = db.add_section(database, imported_data) @@ -136,7 +127,7 @@ def main(argv): # Removes database entries before continuing if cl_args.remove_device is not None: print("[database] Removing all results for device '%s'" % cl_args.remove_device) - remove_database_entries(database, {"device": cl_args.remove_device}) + remove_database_entries(database, {"clblast_device": cl_args.remove_device}) io.save_database(database, database_filename) # Retrieves the best performing results |