summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-04-21 21:59:48 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2017-04-21 21:59:48 +0200
commit957aaae6ca725564de74f7e7710d28adb9caedbb (patch)
tree8625af482ebbfe72eb720b7777e09a15c720d166 /scripts
parentcc9ad7b33b30f25f46b5091eaceea9994610f8e7 (diff)
parent3ec14df60e8784f015556bfd65c78536b4eed4b7 (diff)
Merge branch 'development' into benchmarking
Diffstat (limited to 'scripts')
-rwxr-xr-xscripts/database/database.py37
-rw-r--r--scripts/database/database/db.py14
2 files changed, 51 insertions, 0 deletions
diff --git a/scripts/database/database.py b/scripts/database/database.py
index 31f313da..f1af634b 100755
--- a/scripts/database/database.py
+++ b/scripts/database/database.py
@@ -29,6 +29,40 @@ VENDOR_TRANSLATION_TABLE = {
}
+def remove_mismatched_arguments(database):
+ """Checks for tuning results with mis-matched entries and removes them according to user preferences"""
+ kernel_attributes = clblast.DEVICE_TYPE_ATTRIBUTES + clblast.KERNEL_ATTRIBUTES + ["kernel"]
+
+ # For Python 2 and 3 compatibility
+ try:
+ user_input = raw_input
+ except NameError:
+ user_input = input
+ pass
+
+ # Check for mis-matched entries
+ for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes):
+ group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES)
+ if len(group_by_arguments) != 1:
+ print("[database] WARNING: entries for a single kernel with multiple argument values " + str(kernel_group_name))
+ print("[database] Either quit now, or remove all but one of the argument combinations below:")
+ for index, (attribute_group_name, mismatching_entries) in enumerate(group_by_arguments):
+ print("[database] %d: %s" % (index, attribute_group_name))
+ for attribute_group_name, mismatching_entries in group_by_arguments:
+ response = user_input("[database] Remove entries corresponding to %s, [y/n]? " % str(attribute_group_name))
+ if response == "y":
+ for entry in mismatching_entries:
+ database["sections"].remove(entry)
+ print("[database] Removed %d entry/entries" % len(mismatching_entries))
+
+ # Sanity-check: all mis-matched entries should be removed
+ for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes):
+ group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES)
+ if len(group_by_arguments) != 1:
+ print("[database] ERROR: entries for a single kernel with multiple argument values " + str(kernel_group_name))
+ assert len(group_by_arguments) == 1
+
+
def main(argv):
# Parses the command-line arguments
@@ -76,6 +110,9 @@ def main(argv):
new_size = db.length(database)
print("with " + str(new_size - old_size) + " new items") # Newline printed here
+ # Checks for tuning results with mis-matched entries
+ remove_mismatched_arguments(database)
+
# Stores the modified database back to disk
if len(glob.glob(json_files)) >= 1:
io.save_database(database, database_filename)
diff --git a/scripts/database/database/db.py b/scripts/database/database/db.py
index 94948b1a..51c9f1ec 100644
--- a/scripts/database/database/db.py
+++ b/scripts/database/database/db.py
@@ -5,6 +5,9 @@
# Author(s):
# Cedric Nugteren <www.cedricnugteren.nl>
+import itertools
+from operator import itemgetter
+
import clblast
@@ -62,3 +65,14 @@ def combine_result(old_results, new_result):
# No match found: append a new result
old_results.append(new_result)
return old_results
+
+
+def group_by(database, attributes):
+ """Returns an list with the name of the group and the corresponding entries in the database"""
+ assert len(database) > 0
+ attributes = [a for a in attributes if a in database[0]]
+ database.sort(key=itemgetter(*attributes))
+ result = []
+ for key, data in itertools.groupby(database, key=itemgetter(*attributes)):
+ result.append((key, list(data)))
+ return result