summaryrefslogtreecommitdiff
path: root/scripts/database/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/database/database.py')
-rw-r--r--scripts/database/database.py23
1 files changed, 18 insertions, 5 deletions
diff --git a/scripts/database/database.py b/scripts/database/database.py
index 8e8f37f8..49bc1801 100644
--- a/scripts/database/database.py
+++ b/scripts/database/database.py
@@ -143,7 +143,8 @@ def CalculateDefaults(df):
groups = dfdefault.groupby(DEVICETYPE_ATTRIBUTES+KERNEL_ATTRIBUTES+["kernel"])
for name, dfgroup in groups:
if len(dfgroup) != 1:
- print("[WARNING] Entries for a single kernel with multiple argument values")
+ description = dfgroup["kernel"].min() + " " + dfgroup["device_vendor"].min()
+ print("[WARNING] Entries for a single kernel with multiple argument values: " + description)
# Defaults in general
groups = df.groupby(KERNEL_ATTRIBUTES+ARGUMENT_ATTRIBUTES+["kernel"])
@@ -189,13 +190,20 @@ def GetFooter():
# The start of a new C++ precision entry
def GetPrecision(family, precision):
- precisionstring = "Single"
- if precision == "64":
+ precisionstring = ""
+ if precision == "16":
+ precisionstring = "Half"
+ elif precision == "32":
+ precisionstring = "Single"
+ elif precision == "64":
precisionstring = "Double"
elif precision == "3232":
precisionstring = "ComplexSingle"
elif precision == "6464":
precisionstring = "ComplexDouble"
+ else:
+ print("[ERROR] Unknown precision")
+ sys.exit()
return("\n\nconst Database::DatabaseEntry Database::%s%s = {\n \"%s\", Precision::k%s, {\n"
% (family.title(), precisionstring, family.title(), precisionstring))
@@ -211,7 +219,7 @@ def PrintData(df, outputdir):
# Iterates over the kernel families: creates a new file per family
for family, dffamily in df.groupby(["kernel_family"]):
dffamily = dffamily.dropna(axis=1, how='all')
- f = open(os.path.join(outputdir, family+'.h'), 'w+')
+ f = open(os.path.join(outputdir, family+'.hpp'), 'w+')
f.write(GetHeader(family))
# Loops over the different entries for this family and prints their headers
@@ -294,6 +302,11 @@ if len(glob.glob(glob_json)) >= 1:
print("## Storing the database to disk...")
SaveDatabase(database, file_db)
+# Optional: update the database here. Default is disabled, code below is just an example
+if False:
+ database = UpdateDatabase(database, ((database["kernel"] == "CopyMatrixFast") & (database["precision"] == "3232")), "arg_alpha", "2+0.5i")
+ SaveDatabase(database, file_db)
+
# Retrieves the best performing results
print("## Calculating the best results per device/kernel...")
bests = GetBestResults(database)
@@ -303,7 +316,7 @@ defaults = CalculateDefaults(bests)
bests = ConcatenateData(bests, defaults)
# Outputs the data as a C++ database
-path_cpp_database = os.path.join(path_clblast, "include", "internal", "database")
+path_cpp_database = os.path.join(path_clblast, "src", "database", "kernels")
print("## Producing a C++ database in '"+path_cpp_database+"'...")
PrintData(bests, path_cpp_database)