summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-11-20 21:38:23 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2017-11-20 21:38:23 +0100
commit606990af6f7297528dcc44f67ce777e1ba56d2d0 (patch)
tree477ed2d41ae70f24230be9afd4d4d3999ed4133f /scripts
parent0f080bbc6e269f686f3cd58a7a2395b96e9cde08 (diff)
Made the database script properly handle multiple entries for a single device
Diffstat (limited to 'scripts')
-rw-r--r--scripts/database/database/clblast.py33
1 files changed, 31 insertions, 2 deletions
diff --git a/scripts/database/database/clblast.py b/scripts/database/database/clblast.py
index 1a541fff..29778ecb 100644
--- a/scripts/database/database/clblast.py
+++ b/scripts/database/database/clblast.py
@@ -99,6 +99,36 @@ def print_as_name(name):
return "Name{\"%-50s\"}" % name.strip()[:STRING_LENGTH]
+def get_kernel_database_results(kernel_database):
+ """Retrieves the best result from a group of results. Asserts for valid data"""
+ assert len(kernel_database) >= 1
+
+ all_results = [item["results"] for item in kernel_database]
+
+ best_results = all_results[0]
+ for results in all_results:
+
+ # Debugging in case of unexpected results
+ length_assumption = (len(results) == 1)
+ params_assumption = (sorted(results[0]["parameters"]) == sorted(best_results[0]["parameters"]))
+ if not length_assumption or not params_assumption:
+ print("[database] ERROR: Found %d kernel databases, expected 1" % len(kernel_database))
+ all_keys = sorted([key for item in kernel_database for key in item.keys()])
+ missing_keys = set([x for x in all_keys if all_keys.count(x) != len(kernel_database)])
+ print("[database] All keys in databases: %s" % str(set(all_keys)))
+ print("[database] Missing keys in one or more databases: %s" % str(missing_keys))
+ for index, item in enumerate(kernel_database):
+ print("[database] %d:" % index)
+ print(item)
+ assert length_assumption
+ assert params_assumption
+
+ if results[0]["time"] < best_results[0]["time"]:
+ best_results = results
+
+ return best_results
+
+
def print_cpp_database(database, output_dir):
"""Outputs the database as C++ code"""
@@ -173,8 +203,7 @@ def print_cpp_database(database, output_dir):
kernels = sorted(set([s["kernel"] for s in device_database]))
for kernel in kernels:
kernel_database = [s for s in device_database if s["kernel"] == kernel]
- assert len(kernel_database) == 1
- results = kernel_database[0]["results"]
+ results = get_kernel_database_results(kernel_database)
assert len(results) == 1
new_parameters = results[0]["parameters"]