summaryrefslogtreecommitdiff
path: root/scripts/database/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/database/database.py')
-rwxr-xr-xscripts/database/database.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/scripts/database/database.py b/scripts/database/database.py
index 31f313da..f1af634b 100755
--- a/scripts/database/database.py
+++ b/scripts/database/database.py
@@ -29,6 +29,40 @@ VENDOR_TRANSLATION_TABLE = {
}
+def remove_mismatched_arguments(database):
+ """Checks for tuning results with mis-matched entries and removes them according to user preferences"""
+ kernel_attributes = clblast.DEVICE_TYPE_ATTRIBUTES + clblast.KERNEL_ATTRIBUTES + ["kernel"]
+
+ # For Python 2 and 3 compatibility
+ try:
+ user_input = raw_input
+ except NameError:
+ user_input = input
+ pass
+
+ # Check for mis-matched entries
+ for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes):
+ group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES)
+ if len(group_by_arguments) != 1:
+ print("[database] WARNING: entries for a single kernel with multiple argument values " + str(kernel_group_name))
+ print("[database] Either quit now, or remove all but one of the argument combinations below:")
+ for index, (attribute_group_name, mismatching_entries) in enumerate(group_by_arguments):
+ print("[database] %d: %s" % (index, attribute_group_name))
+ for attribute_group_name, mismatching_entries in group_by_arguments:
+ response = user_input("[database] Remove entries corresponding to %s, [y/n]? " % str(attribute_group_name))
+ if response == "y":
+ for entry in mismatching_entries:
+ database["sections"].remove(entry)
+ print("[database] Removed %d entry/entries" % len(mismatching_entries))
+
+ # Sanity-check: all mis-matched entries should be removed
+ for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes):
+ group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES)
+ if len(group_by_arguments) != 1:
+ print("[database] ERROR: entries for a single kernel with multiple argument values " + str(kernel_group_name))
+ assert len(group_by_arguments) == 1
+
+
def main(argv):
# Parses the command-line arguments
@@ -76,6 +110,9 @@ def main(argv):
new_size = db.length(database)
print("with " + str(new_size - old_size) + " new items") # Newline printed here
+ # Checks for tuning results with mis-matched entries
+ remove_mismatched_arguments(database)
+
# Stores the modified database back to disk
if len(glob.glob(json_files)) >= 1:
io.save_database(database, database_filename)