summaryrefslogtreecommitdiff
path: root/scripts/database/database.py
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-09-16 20:37:09 +0200
committerGitHub <noreply@github.com>2017-09-16 20:37:09 +0200
commit7d0ef8e10d05ee3a18360295c021ab6a6ef32c2d (patch)
tree55d795f06769134601f017f50d505a6c8904d398 /scripts/database/database.py
parentbb947890dec90712c92028c20234eafd48e6fa3e (diff)
parentbcf39eb79a8252b9f9b0c31311c7951abc8520ee (diff)
Merge pull request #191 from CNugteren/database_improvements
Database improvements
Diffstat (limited to 'scripts/database/database.py')
-rwxr-xr-xscripts/database/database.py29
1 files changed, 10 insertions, 19 deletions
diff --git a/scripts/database/database.py b/scripts/database/database.py
index e398aa30..8f3ccce6 100755
--- a/scripts/database/database.py
+++ b/scripts/database/database.py
@@ -20,14 +20,6 @@ import database.defaults as defaults
# Server storing a copy of the database
DATABASE_SERVER_URL = "https://raw.githubusercontent.com/CNugteren/CLBlast-database/master/database.json"
-# OpenCL vendor names and their short name
-VENDOR_TRANSLATION_TABLE = {
- "GenuineIntel": "Intel",
- "Intel(R) Corporation": "Intel",
- "Advanced Micro Devices, Inc.": "AMD",
- "NVIDIA Corporation": "NVIDIA",
-}
-
def remove_mismatched_arguments(database):
"""Checks for tuning results with mis-matched entries and removes them according to user preferences"""
@@ -44,12 +36,14 @@ def remove_mismatched_arguments(database):
for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes):
group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES)
if len(group_by_arguments) != 1:
- print("[database] WARNING: entries for a single kernel with multiple argument values " + str(kernel_group_name))
- print("[database] Either quit now, or remove all but one of the argument combinations below:")
+ print("[database] WARNING: entries for a single kernel with multiple argument values " +
+ str(kernel_group_name))
+ print("[database] Either quit or remove all but one of the argument combinations below:")
for index, (attribute_group_name, mismatching_entries) in enumerate(group_by_arguments):
print("[database] %d: %s" % (index, attribute_group_name))
for attribute_group_name, mismatching_entries in group_by_arguments:
- response = user_input("[database] Remove entries corresponding to %s, [y/n]? " % str(attribute_group_name))
+ response = user_input("[database] Remove entries corresponding to %s, [y/n]? " %
+ str(attribute_group_name))
if response == "y":
for entry in mismatching_entries:
database["sections"].remove(entry)
@@ -59,7 +53,8 @@ def remove_mismatched_arguments(database):
for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes):
group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES)
if len(group_by_arguments) != 1:
- print("[database] ERROR: entries for a single kernel with multiple argument values " + str(kernel_group_name))
+ print("[database] ERROR: entries for a single kernel with multiple argument values " +
+ str(kernel_group_name))
assert len(group_by_arguments) == 1
@@ -97,7 +92,8 @@ def main(argv):
# Checks whether the command-line arguments are valid
clblast_header = os.path.join(cl_args.clblast_root, "include", "clblast.h") # Not used but just for validation
if not os.path.isfile(clblast_header):
- raise RuntimeError("The path '" + cl_args.clblast_root + "' does not point to the root of the CLBlast library")
+ raise RuntimeError("The path '" + cl_args.clblast_root +
+ "' does not point to the root of the CLBlast library")
if len(glob.glob(json_files)) < 1:
print("[database] The path '" + cl_args.source_folder + "' does not contain any JSON files")
@@ -115,11 +111,6 @@ def main(argv):
sys.stdout.write("[database] Processing '" + file_json + "' ") # No newline printed
imported_data = io.load_tuning_results(file_json)
- # Fixes the problem that some vendors use multiple different names
- for target in VENDOR_TRANSLATION_TABLE:
- if imported_data["device_vendor"] == target:
- imported_data["device_vendor"] = VENDOR_TRANSLATION_TABLE[target]
-
# Adds the new data to the database
old_size = db.length(database)
database = db.add_section(database, imported_data)
@@ -136,7 +127,7 @@ def main(argv):
# Removes database entries before continuing
if cl_args.remove_device is not None:
print("[database] Removing all results for device '%s'" % cl_args.remove_device)
- remove_database_entries(database, {"device": cl_args.remove_device})
+ remove_database_entries(database, {"clblast_device": cl_args.remove_device})
io.save_database(database, database_filename)
# Retrieves the best performing results