diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-06-20 21:19:26 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-06-20 21:19:26 +0200 |
commit | e44feb85763c5dbae66beb51892d8dda2126e04c (patch) | |
tree | 965ed12eee66d34069a700c1310ddf8d298709f9 /scripts | |
parent | 48f2682eb7ee72b0f9e6f2922569fcf352f8ce5f (diff) |
Changed the structure of the database to reduce compilation time and save memory
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/database/database/clblast.py | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/scripts/database/database/clblast.py b/scripts/database/database/clblast.py index 8af3ab5b..542ff7b7 100644 --- a/scripts/database/database/clblast.py +++ b/scripts/database/database/clblast.py @@ -67,7 +67,7 @@ def get_cpp_precision(family, precision): """Retrieves the C++ code for the start of a new precision""" precision_string = precision_to_string(precision) camelcase_name = family.title().replace("_", "") - return("\n\nconst Database::DatabaseEntry %s%s = {\n \"%s\", Precision::k%s, {\n" + return("\n\nconst Database::DatabaseEntry %s%s = {\n \"%s\", Precision::k%s" % (camelcase_name, precision_string, camelcase_name, precision_string)) @@ -108,6 +108,15 @@ def print_cpp_database(database, output_dir): and s["device_type"] == DEVICE_TYPE_DEFAULT and s["device"] == DEVICE_NAME_DEFAULT] + # Discovers the parameters for this kernel + parameter_names = [] + for example_data in precision_database: + for example_result in example_data["results"]: + parameter_names.extend([str(k) for k in example_result["parameters"].keys()]) + parameter_names = sorted(set(parameter_names)) + parameter_names_as_string = " ".join(['"%s"' % p for p in parameter_names]) + f.write(", {" + parameter_names_as_string + "}, {\n") + # Loops over device vendors (e.g. AMD) device_vendors = sorted(set([s["device_vendor"] for s in precision_database])) for vendor in device_vendors: @@ -129,6 +138,7 @@ def print_cpp_database(database, output_dir): # Collects the parameters for this entry parameters = [] + parmameter_index = 0 kernels = sorted(set([s["kernel"] for s in device_database])) for kernel in kernels: kernel_database = [s for s in device_database if s["kernel"] == kernel] @@ -139,8 +149,10 @@ def print_cpp_database(database, output_dir): assert len(results) == 1 new_parameters = results[0]["parameters"] for parameter_name in sorted(new_parameters): + assert parameter_name == parameter_names[parmameter_index] parameter_value = new_parameters[parameter_name] - parameters.append("{\"" + parameter_name + "\"," + str(parameter_value) + "}") + parameters.append(str(parameter_value)) + parmameter_index += 1 # Prints the entry f.write(", ".join(parameters)) |