diff options
Diffstat (limited to 'scripts/database/database.py')
-rw-r--r-- | scripts/database/database.py | 23 |
1 files changed, 18 insertions, 5 deletions
diff --git a/scripts/database/database.py b/scripts/database/database.py index 8e8f37f8..49bc1801 100644 --- a/scripts/database/database.py +++ b/scripts/database/database.py @@ -143,7 +143,8 @@ def CalculateDefaults(df): groups = dfdefault.groupby(DEVICETYPE_ATTRIBUTES+KERNEL_ATTRIBUTES+["kernel"]) for name, dfgroup in groups: if len(dfgroup) != 1: - print("[WARNING] Entries for a single kernel with multiple argument values") + description = dfgroup["kernel"].min() + " " + dfgroup["device_vendor"].min() + print("[WARNING] Entries for a single kernel with multiple argument values: " + description) # Defaults in general groups = df.groupby(KERNEL_ATTRIBUTES+ARGUMENT_ATTRIBUTES+["kernel"]) @@ -189,13 +190,20 @@ def GetFooter(): # The start of a new C++ precision entry def GetPrecision(family, precision): - precisionstring = "Single" - if precision == "64": + precisionstring = "" + if precision == "16": + precisionstring = "Half" + elif precision == "32": + precisionstring = "Single" + elif precision == "64": precisionstring = "Double" elif precision == "3232": precisionstring = "ComplexSingle" elif precision == "6464": precisionstring = "ComplexDouble" + else: + print("[ERROR] Unknown precision") + sys.exit() return("\n\nconst Database::DatabaseEntry Database::%s%s = {\n \"%s\", Precision::k%s, {\n" % (family.title(), precisionstring, family.title(), precisionstring)) @@ -211,7 +219,7 @@ def PrintData(df, outputdir): # Iterates over the kernel families: creates a new file per family for family, dffamily in df.groupby(["kernel_family"]): dffamily = dffamily.dropna(axis=1, how='all') - f = open(os.path.join(outputdir, family+'.h'), 'w+') + f = open(os.path.join(outputdir, family+'.hpp'), 'w+') f.write(GetHeader(family)) # Loops over the different entries for this family and prints their headers @@ -294,6 +302,11 @@ if len(glob.glob(glob_json)) >= 1: print("## Storing the database to disk...") SaveDatabase(database, file_db) +# Optional: update the database here. Default is disabled, code below is just an example +if False: + database = UpdateDatabase(database, ((database["kernel"] == "CopyMatrixFast") & (database["precision"] == "3232")), "arg_alpha", "2+0.5i") + SaveDatabase(database, file_db) + # Retrieves the best performing results print("## Calculating the best results per device/kernel...") bests = GetBestResults(database) @@ -303,7 +316,7 @@ defaults = CalculateDefaults(bests) bests = ConcatenateData(bests, defaults) # Outputs the data as a C++ database -path_cpp_database = os.path.join(path_clblast, "include", "internal", "database") +path_cpp_database = os.path.join(path_clblast, "src", "database", "kernels") print("## Producing a C++ database in '"+path_cpp_database+"'...") PrintData(bests, path_cpp_database) |