summaryrefslogtreecommitdiff
path: root/scripts/database/database.py
blob: 6d370d998b4a9a178ca5ee8dc88e2a8407c982b5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#!/usr/bin/env python

# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
# PEP8 Python style guide and uses a max-width of 120 characters per line.
#
# Author(s):
#   Cedric Nugteren <www.cedricnugteren.nl>

import sys
import os.path
import glob
import argparse

import pandas as pd

import database.io as io
import database.db as db
import database.clblast as clblast
import database.bests as bests
import database.defaults as defaults

# Location of the server-side copy of the database
DATABASE_SERVER_URL = "http://www.cedricnugteren.nl/tuning/clblast.db"

# Translation from the various names a vendor reports through OpenCL to one short name per vendor
VENDOR_TRANSLATION_TABLE = {
    "device_vendor": {
        "GenuineIntel": "Intel",
        "Intel(R) Corporation": "Intel",
        "Advanced Micro Devices, Inc.": "AMD",
        "NVIDIA Corporation": "NVIDIA",
    },
}


def main(argv):
    """Adds new JSON tuning results to the local database and regenerates the C++ kernel database.

    Args:
        argv: The command-line arguments without the program name: the folder with JSON files, the root of the
            CLBlast sources, and optionally '-v' / '--verbose'.

    Raises:
        RuntimeError: If 'clblast_root' does not contain 'include/clblast.h'.
    """

    # Parses the command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("source_folder", help="The folder with JSON files to parse to add to the database")
    parser.add_argument("clblast_root", help="Root of the CLBlast sources")
    parser.add_argument("-v", "--verbose", action="store_true", help="Increase verbosity of the script")
    cl_args = parser.parse_args(argv)

    # Parses the path arguments
    database_filename = os.path.join(cl_args.clblast_root, "scripts", "database", "database.db")
    json_pattern = os.path.join(cl_args.source_folder, "*.json")
    cpp_database_path = os.path.join(cl_args.clblast_root, "src", "database", "kernels")

    # Checks whether the command-line arguments are valid
    clblast_header = os.path.join(cl_args.clblast_root, "include", "clblast.h")  # Not used but just for validation
    if not os.path.isfile(clblast_header):
        raise RuntimeError("The path '" + cl_args.clblast_root + "' does not point to the root of the CLBlast library")

    # Lists the JSON files only once, such that the warning below, the import loop, and the decision to store the
    # database all act on the same set of files (even if the folder changes while the script runs)
    json_files = glob.glob(json_pattern)
    if not json_files:
        print("[database] The path '" + cl_args.source_folder + "' does not contain any JSON files")

    # Pandas options
    pd.set_option('display.width', 1000)
    if cl_args.verbose:
        print("[database] Using pandas version " + pd.__version__)

    # Downloads the database if a local copy is not present
    if not os.path.isfile(database_filename):
        io.download_database(database_filename, DATABASE_SERVER_URL)

    # Loads the database from disk
    database = io.load_database(database_filename)

    # Loops over all JSON files in the supplied folder
    for file_json in json_files:

        # Loads the newly imported data
        sys.stdout.write("[database] Processing '" + file_json + "' ")  # No newline printed
        sys.stdout.flush()  # Without a newline the text could otherwise stay buffered until the file is processed
        imported_data = io.load_json_to_pandas(file_json)

        # Fixes the problem that some vendors use multiple different names
        imported_data = db.find_and_replace(imported_data, VENDOR_TRANSLATION_TABLE)

        # Adds the new data to the database
        old_size = len(database.index)
        database = db.concatenate_database(database, imported_data)
        database = db.remove_duplicates(database)
        new_size = len(database.index)
        print("with " + str(new_size - old_size) + " new items")  # Newline printed here

    # Stores the modified database back to disk (only needed when there was something to import)
    if json_files:
        io.save_database(database, database_filename)

    # Optional: update the database here. Default is disabled, code below is just an example
    if False:  # TODO: Use command-line arguments to enable updates in a flexible way
        database = db.update_database(database,
                                      ((database["kernel"] == "CopyMatrixFast") &
                                       (database["precision"] == "3232")),
                                      "arg_alpha", "2+0.5i")
        io.save_database(database, database_filename)

    # Retrieves the best performing results
    print("[database] Calculating the best results per device/kernel...")
    database_best_results = bests.get_best_results(database)

    # Determines the defaults for other vendors and per vendor
    print("[database] Calculating the default values...")
    database_defaults = defaults.calculate_defaults(database, cl_args.verbose)
    database_best_results = db.concatenate_database(database_best_results, database_defaults)

    # Outputs the database as a C++ database
    print("[database] Producing a C++ database in '" + cpp_database_path + "'...")
    clblast.print_cpp_database(database_best_results, cpp_database_path)

    print("[database] All done")


# Runs the database script when executed directly, passing on all arguments except the program name
if __name__ == "__main__":
    main(sys.argv[1:])