diff options
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/database/database/clblast.py | 16 | ||||
-rwxr-xr-x | scripts/generator/generator.py | 2 |
2 files changed, 15 insertions, 3 deletions
diff --git a/scripts/database/database/clblast.py b/scripts/database/database/clblast.py index 8af3ab5b..803d1d2a 100644 --- a/scripts/database/database/clblast.py +++ b/scripts/database/database/clblast.py @@ -67,7 +67,7 @@ def get_cpp_precision(family, precision): """Retrieves the C++ code for the start of a new precision""" precision_string = precision_to_string(precision) camelcase_name = family.title().replace("_", "") - return("\n\nconst Database::DatabaseEntry %s%s = {\n \"%s\", Precision::k%s, {\n" + return("\n\nconst Database::DatabaseEntry %s%s = {\n \"%s\", Precision::k%s" % (camelcase_name, precision_string, camelcase_name, precision_string)) @@ -108,6 +108,15 @@ def print_cpp_database(database, output_dir): and s["device_type"] == DEVICE_TYPE_DEFAULT and s["device"] == DEVICE_NAME_DEFAULT] + # Discovers the parameters for this kernel + parameter_names = [] + for example_data in precision_database: + for example_result in example_data["results"]: + parameter_names.extend([str(k) for k in example_result["parameters"].keys()]) + parameter_names = sorted(set(parameter_names)) + parameter_names_as_string = ", ".join(['"%s"' % p for p in parameter_names]) + f.write(", {" + parameter_names_as_string + "}, {\n") + # Loops over device vendors (e.g. AMD) device_vendors = sorted(set([s["device_vendor"] for s in precision_database])) for vendor in device_vendors: @@ -129,6 +138,7 @@ def print_cpp_database(database, output_dir): # Collects the parameters for this entry parameters = [] + parmameter_index = 0 kernels = sorted(set([s["kernel"] for s in device_database])) for kernel in kernels: kernel_database = [s for s in device_database if s["kernel"] == kernel] @@ -139,8 +149,10 @@ def print_cpp_database(database, output_dir): assert len(results) == 1 new_parameters = results[0]["parameters"] for parameter_name in sorted(new_parameters): + assert parameter_name == parameter_names[parmameter_index] parameter_value = new_parameters[parameter_name] - parameters.append("{\"" + parameter_name + "\"," + str(parameter_value) + "}") + parameters.append(str(parameter_value)) + parmameter_index += 1 # Prints the entry f.write(", ".join(parameters)) diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index 0d0ee29c..74e0815a 100755 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -43,7 +43,7 @@ FILES = [ "/src/clblast_netlib_c.cpp", ] HEADER_LINES = [122, 78, 126, 24, 29, 41, 29, 65, 32] -FOOTER_LINES = [25, 139, 27, 38, 6, 6, 6, 9, 2] +FOOTER_LINES = [25, 147, 27, 38, 6, 6, 6, 9, 2] HEADER_LINES_DOC = 0 FOOTER_LINES_DOC = 63 |