From 606990af6f7297528dcc44f67ce777e1ba56d2d0 Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Mon, 20 Nov 2017 21:38:23 +0100 Subject: Made the database script properly handle multiple entries for a single device --- scripts/database/database/clblast.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) (limited to 'scripts') diff --git a/scripts/database/database/clblast.py b/scripts/database/database/clblast.py index 1a541fff..29778ecb 100644 --- a/scripts/database/database/clblast.py +++ b/scripts/database/database/clblast.py @@ -99,6 +99,36 @@ def print_as_name(name): return "Name{\"%-50s\"}" % name.strip()[:STRING_LENGTH] +def get_kernel_database_results(kernel_database): + """Retrieves the best result from a group of results. Asserts for valid data""" + assert len(kernel_database) >= 1 + + all_results = [item["results"] for item in kernel_database] + + best_results = all_results[0] + for results in all_results: + + # Debugging in case of unexpected results + length_assumption = (len(results) == 1) + params_assumption = (sorted(results[0]["parameters"]) == sorted(best_results[0]["parameters"])) + if not length_assumption or not params_assumption: + print("[database] ERROR: Found %d kernel databases, expected 1" % len(kernel_database)) + all_keys = sorted([key for item in kernel_database for key in item.keys()]) + missing_keys = set([x for x in all_keys if all_keys.count(x) != len(kernel_database)]) + print("[database] All keys in databases: %s" % str(set(all_keys))) + print("[database] Missing keys in one or more databases: %s" % str(missing_keys)) + for index, item in enumerate(kernel_database): + print("[database] %d:" % index) + print(item) + assert length_assumption + assert params_assumption + + if results[0]["time"] < best_results[0]["time"]: + best_results = results + + return best_results + + def print_cpp_database(database, output_dir): """Outputs the database as C++ code""" @@ -173,8 +203,7 @@ def print_cpp_database(database, output_dir): kernels = sorted(set([s["kernel"] for s in device_database])) for kernel in kernels: kernel_database = [s for s in device_database if s["kernel"] == kernel] - assert len(kernel_database) == 1 - results = kernel_database[0]["results"] + results = get_kernel_database_results(kernel_database) assert len(results) == 1 new_parameters = results[0]["parameters"] -- cgit v1.2.3