diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-04-17 15:00:45 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-04-17 15:00:45 +0200 |
commit | 3ec14df60e8784f015556bfd65c78536b4eed4b7 (patch) | |
tree | cbec1ba326b4ec348ba18b0cc0942ef5c2aec16e | |
parent | b20c518f9fd05a69957c2018e72c6a648f5cdb7d (diff) |
Added proper handling of mismatched arguments in the database script
-rwxr-xr-x | scripts/database/database.py | 37 | ||||
-rw-r--r-- | scripts/database/database/db.py | 14 |
2 files changed, 51 insertions, 0 deletions
diff --git a/scripts/database/database.py b/scripts/database/database.py index 31f313da..f1af634b 100755 --- a/scripts/database/database.py +++ b/scripts/database/database.py @@ -29,6 +29,40 @@ VENDOR_TRANSLATION_TABLE = { } +def remove_mismatched_arguments(database): + """Checks for tuning results with mis-matched entries and removes them according to user preferences""" + kernel_attributes = clblast.DEVICE_TYPE_ATTRIBUTES + clblast.KERNEL_ATTRIBUTES + ["kernel"] + + # For Python 2 and 3 compatibility + try: + user_input = raw_input + except NameError: + user_input = input + pass + + # Check for mis-matched entries + for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes): + group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES) + if len(group_by_arguments) != 1: + print("[database] WARNING: entries for a single kernel with multiple argument values " + str(kernel_group_name)) + print("[database] Either quit now, or remove all but one of the argument combinations below:") + for index, (attribute_group_name, mismatching_entries) in enumerate(group_by_arguments): + print("[database] %d: %s" % (index, attribute_group_name)) + for attribute_group_name, mismatching_entries in group_by_arguments: + response = user_input("[database] Remove entries corresponding to %s, [y/n]? " % str(attribute_group_name)) + if response == "y": + for entry in mismatching_entries: + database["sections"].remove(entry) + print("[database] Removed %d entry/entries" % len(mismatching_entries)) + + # Sanity-check: all mis-matched entries should be removed + for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes): + group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES) + if len(group_by_arguments) != 1: + print("[database] ERROR: entries for a single kernel with multiple argument values " + str(kernel_group_name)) + assert len(group_by_arguments) == 1 + + def main(argv): # Parses the command-line arguments @@ -76,6 +110,9 @@ def main(argv): new_size = db.length(database) print("with " + str(new_size - old_size) + " new items") # Newline printed here + # Checks for tuning results with mis-matched entries + remove_mismatched_arguments(database) + # Stores the modified database back to disk if len(glob.glob(json_files)) >= 1: io.save_database(database, database_filename) diff --git a/scripts/database/database/db.py b/scripts/database/database/db.py index 94948b1a..51c9f1ec 100644 --- a/scripts/database/database/db.py +++ b/scripts/database/database/db.py @@ -5,6 +5,9 @@ # Author(s): # Cedric Nugteren <www.cedricnugteren.nl> +import itertools +from operator import itemgetter + import clblast @@ -62,3 +65,14 @@ def combine_result(old_results, new_result): # No match found: append a new result old_results.append(new_result) return old_results + + +def group_by(database, attributes): + """Returns an list with the name of the group and the corresponding entries in the database""" + assert len(database) > 0 + attributes = [a for a in attributes if a in database[0]] + database.sort(key=itemgetter(*attributes)) + result = [] + for key, data in itertools.groupby(database, key=itemgetter(*attributes)): + result.append((key, list(data))) + return result |