From 3ec14df60e8784f015556bfd65c78536b4eed4b7 Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Mon, 17 Apr 2017 15:00:45 +0200 Subject: Added proper handling of mismatched arguments in the database script --- scripts/database/database.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) (limited to 'scripts/database/database.py') diff --git a/scripts/database/database.py b/scripts/database/database.py index 31f313da..f1af634b 100755 --- a/scripts/database/database.py +++ b/scripts/database/database.py @@ -29,6 +29,40 @@ VENDOR_TRANSLATION_TABLE = { } +def remove_mismatched_arguments(database): + """Checks for tuning results with mis-matched entries and removes them according to user preferences""" + kernel_attributes = clblast.DEVICE_TYPE_ATTRIBUTES + clblast.KERNEL_ATTRIBUTES + ["kernel"] + + # For Python 2 and 3 compatibility + try: + user_input = raw_input + except NameError: + user_input = input + pass + + # Check for mis-matched entries + for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes): + group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES) + if len(group_by_arguments) != 1: + print("[database] WARNING: entries for a single kernel with multiple argument values " + str(kernel_group_name)) + print("[database] Either quit now, or remove all but one of the argument combinations below:") + for index, (attribute_group_name, mismatching_entries) in enumerate(group_by_arguments): + print("[database] %d: %s" % (index, attribute_group_name)) + for attribute_group_name, mismatching_entries in group_by_arguments: + response = user_input("[database] Remove entries corresponding to %s, [y/n]? " % str(attribute_group_name)) + if response == "y": + for entry in mismatching_entries: + database["sections"].remove(entry) + print("[database] Removed %d entry/entries" % len(mismatching_entries)) + + # Sanity-check: all mis-matched entries should be removed + for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes): + group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES) + if len(group_by_arguments) != 1: + print("[database] ERROR: entries for a single kernel with multiple argument values " + str(kernel_group_name)) + assert len(group_by_arguments) == 1 + + def main(argv): # Parses the command-line arguments @@ -76,6 +110,9 @@ def main(argv): new_size = db.length(database) print("with " + str(new_size - old_size) + " new items") # Newline printed here + # Checks for tuning results with mis-matched entries + remove_mismatched_arguments(database) + # Stores the modified database back to disk if len(glob.glob(json_files)) >= 1: io.save_database(database, database_filename) -- cgit v1.2.3