diff options
Diffstat (limited to 'scripts/database/database.py')
-rwxr-xr-x | scripts/database/database.py | 48 |
1 files changed, 20 insertions, 28 deletions
diff --git a/scripts/database/database.py b/scripts/database/database.py index e115d68c..f758a2b7 100755 --- a/scripts/database/database.py +++ b/scripts/database/database.py @@ -11,8 +11,6 @@ import os.path import glob import argparse -import pandas as pd - import database.io as io import database.db as db import database.clblast as clblast @@ -20,15 +18,15 @@ import database.bests as bests import database.defaults as defaults # Server storing a copy of the database -DATABASE_SERVER_URL = "http://www.cedricnugteren.nl/tuning/clblast.db" +DATABASE_SERVER_URL = "http://www.cedricnugteren.nl/tuning/clblast.json" # OpenCL vendor names and their short name -VENDOR_TRANSLATION_TABLE = {"device_vendor": { +VENDOR_TRANSLATION_TABLE = { "GenuineIntel": "Intel", "Intel(R) Corporation": "Intel", "Advanced Micro Devices, Inc.": "AMD", "NVIDIA Corporation": "NVIDIA", -}} +} def main(argv): @@ -41,7 +39,8 @@ def main(argv): cl_args = parser.parse_args(argv) # Parses the path arguments - database_filename = os.path.join(cl_args.clblast_root, "scripts", "database", "database.db") + database_filename = os.path.join(cl_args.clblast_root, "scripts", "database", "database.json") + database_best_filename = os.path.join(cl_args.clblast_root, "scripts", "database", "database_best.json") json_files = os.path.join(cl_args.source_folder, "*.json") cpp_database_path = os.path.join(cl_args.clblast_root, "src", "database", "kernels") @@ -52,11 +51,6 @@ def main(argv): if len(glob.glob(json_files)) < 1: print("[database] The path '" + cl_args.source_folder + "' does not contain any JSON files") - # Pandas options - pd.set_option('display.width', 1000) - if cl_args.verbose: - print("[database] Using pandas version " + pd.__version__) - # Downloads the database if a local copy is not present if not os.path.isfile(database_filename): io.download_database(database_filename, DATABASE_SERVER_URL) @@ -68,38 +62,36 @@ def main(argv): for file_json in glob.glob(json_files): # Loads the newly imported data - sys.stdout.write("[database] Processing '"+file_json+"' ") # No newline printed - imported_data = io.load_json_to_pandas(file_json) + sys.stdout.write("[database] Processing '" + file_json + "' ") # No newline printed + imported_data = io.load_tuning_results(file_json) # Fixes the problem that some vendors use multiple different names - imported_data = db.find_and_replace(imported_data, VENDOR_TRANSLATION_TABLE) + for target in VENDOR_TRANSLATION_TABLE: + if imported_data["device_vendor"] == target: + imported_data["device_vendor"] = VENDOR_TRANSLATION_TABLE[target] # Adds the new data to the database - old_size = len(database.index) - database = db.concatenate_database(database, imported_data) - database = db.remove_duplicates(database) - new_size = len(database.index) + old_size = db.length(database) + database = db.add_section(database, imported_data) + new_size = db.length(database) print("with " + str(new_size - old_size) + " new items") # Newline printed here # Stores the modified database back to disk if len(glob.glob(json_files)) >= 1: io.save_database(database, database_filename) - # Optional: update the database here. Default is disabled, code below is just an example - if False: # TODO: Use command-line arguments to enable updates in a flexible way - database = db.update_database(database, - ((database["kernel"] == "CopyMatrixFast") & - (database["precision"] == "3232")), - "arg_alpha", "2+0.5i") - io.save_database(database, database_filename) - # Retrieves the best performing results print("[database] Calculating the best results per device/kernel...") database_best_results = bests.get_best_results(database) # Determines the defaults for other vendors and per vendor - database_defaults = defaults.calculate_defaults(database_best_results) - database_best_results = db.concatenate_database(database_best_results, database_defaults) + print("[database] Calculating the default values...") + database_defaults = defaults.calculate_defaults(database, cl_args.verbose) + database_best_results["sections"].extend(database_defaults["sections"]) + + # Optionally outputs the database to disk + if cl_args.verbose: + io.save_database(database_best_results, database_best_filename) # Outputs the database as a C++ database print("[database] Producing a C++ database in '" + cpp_database_path + "'...") |