summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rw-r--r--scripts/database/database/clblast.py16
-rwxr-xr-xscripts/generator/generator.py2
2 files changed, 15 insertions, 3 deletions
diff --git a/scripts/database/database/clblast.py b/scripts/database/database/clblast.py
index 8af3ab5b..803d1d2a 100644
--- a/scripts/database/database/clblast.py
+++ b/scripts/database/database/clblast.py
@@ -67,7 +67,7 @@ def get_cpp_precision(family, precision):
"""Retrieves the C++ code for the start of a new precision"""
precision_string = precision_to_string(precision)
camelcase_name = family.title().replace("_", "")
- return("\n\nconst Database::DatabaseEntry %s%s = {\n \"%s\", Precision::k%s, {\n"
+ return("\n\nconst Database::DatabaseEntry %s%s = {\n \"%s\", Precision::k%s"
% (camelcase_name, precision_string, camelcase_name, precision_string))
@@ -108,6 +108,15 @@ def print_cpp_database(database, output_dir):
and s["device_type"] == DEVICE_TYPE_DEFAULT
and s["device"] == DEVICE_NAME_DEFAULT]
+ # Discovers the parameters for this kernel
+ parameter_names = []
+ for example_data in precision_database:
+ for example_result in example_data["results"]:
+ parameter_names.extend([str(k) for k in example_result["parameters"].keys()])
+ parameter_names = sorted(set(parameter_names))
+ parameter_names_as_string = ", ".join(['"%s"' % p for p in parameter_names])
+ f.write(", {" + parameter_names_as_string + "}, {\n")
+
# Loops over device vendors (e.g. AMD)
device_vendors = sorted(set([s["device_vendor"] for s in precision_database]))
for vendor in device_vendors:
@@ -129,6 +138,7 @@ def print_cpp_database(database, output_dir):
# Collects the parameters for this entry
parameters = []
+ parmameter_index = 0
kernels = sorted(set([s["kernel"] for s in device_database]))
for kernel in kernels:
kernel_database = [s for s in device_database if s["kernel"] == kernel]
@@ -139,8 +149,10 @@ def print_cpp_database(database, output_dir):
assert len(results) == 1
new_parameters = results[0]["parameters"]
for parameter_name in sorted(new_parameters):
+ assert parameter_name == parameter_names[parmameter_index]
parameter_value = new_parameters[parameter_name]
- parameters.append("{\"" + parameter_name + "\"," + str(parameter_value) + "}")
+ parameters.append(str(parameter_value))
+ parmameter_index += 1
# Prints the entry
f.write(", ".join(parameters))
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 0d0ee29c..74e0815a 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -43,7 +43,7 @@ FILES = [
"/src/clblast_netlib_c.cpp",
]
HEADER_LINES = [122, 78, 126, 24, 29, 41, 29, 65, 32]
-FOOTER_LINES = [25, 139, 27, 38, 6, 6, 6, 9, 2]
+FOOTER_LINES = [25, 147, 27, 38, 6, 6, 6, 9, 2]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 63