summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-06-20 21:19:26 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2017-06-20 21:19:26 +0200
commite44feb85763c5dbae66beb51892d8dda2126e04c (patch)
tree965ed12eee66d34069a700c1310ddf8d298709f9 /scripts
parent48f2682eb7ee72b0f9e6f2922569fcf352f8ce5f (diff)
Changed the structure of the database to reduce compilation time and save memory
Diffstat (limited to 'scripts')
-rw-r--r--scripts/database/database/clblast.py16
1 files changed, 14 insertions, 2 deletions
diff --git a/scripts/database/database/clblast.py b/scripts/database/database/clblast.py
index 8af3ab5b..542ff7b7 100644
--- a/scripts/database/database/clblast.py
+++ b/scripts/database/database/clblast.py
@@ -67,7 +67,7 @@ def get_cpp_precision(family, precision):
"""Retrieves the C++ code for the start of a new precision"""
precision_string = precision_to_string(precision)
camelcase_name = family.title().replace("_", "")
- return("\n\nconst Database::DatabaseEntry %s%s = {\n \"%s\", Precision::k%s, {\n"
+ return("\n\nconst Database::DatabaseEntry %s%s = {\n \"%s\", Precision::k%s"
% (camelcase_name, precision_string, camelcase_name, precision_string))
@@ -108,6 +108,15 @@ def print_cpp_database(database, output_dir):
and s["device_type"] == DEVICE_TYPE_DEFAULT
and s["device"] == DEVICE_NAME_DEFAULT]
+ # Discovers the parameters for this kernel
+ parameter_names = []
+ for example_data in precision_database:
+ for example_result in example_data["results"]:
+ parameter_names.extend([str(k) for k in example_result["parameters"].keys()])
+ parameter_names = sorted(set(parameter_names))
+ parameter_names_as_string = " ".join(['"%s"' % p for p in parameter_names])
+ f.write(", {" + parameter_names_as_string + "}, {\n")
+
# Loops over device vendors (e.g. AMD)
device_vendors = sorted(set([s["device_vendor"] for s in precision_database]))
for vendor in device_vendors:
@@ -129,6 +138,7 @@ def print_cpp_database(database, output_dir):
# Collects the parameters for this entry
parameters = []
+ parmameter_index = 0
kernels = sorted(set([s["kernel"] for s in device_database]))
for kernel in kernels:
kernel_database = [s for s in device_database if s["kernel"] == kernel]
@@ -139,8 +149,10 @@ def print_cpp_database(database, output_dir):
assert len(results) == 1
new_parameters = results[0]["parameters"]
for parameter_name in sorted(new_parameters):
+ assert parameter_name == parameter_names[parmameter_index]
parameter_value = new_parameters[parameter_name]
- parameters.append("{\"" + parameter_name + "\"," + str(parameter_value) + "}")
+ parameters.append(str(parameter_value))
+ parmameter_index += 1
# Prints the entry
f.write(", ".join(parameters))