summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2016-07-24 17:06:27 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2016-07-24 17:06:27 +0200
commit0252df731ab21d4acd5dfe53733e3c1bd0c18200 (patch)
treefae635b9978bf570e2c56735d77c25628f6bbd0c
parent40a72259eba491631d8875aae465c5a93d7fed02 (diff)
parentffa35c623af4b0916f625f3a41000e75a1df7e1f (diff)
Merge branch 'development' into gemv_performance
-rw-r--r--CMakeLists.txt2
-rwxr-xr-x[-rw-r--r--]scripts/database/database.py356
-rw-r--r--scripts/database/database/__init__.py0
-rw-r--r--scripts/database/database/bests.py20
-rw-r--r--scripts/database/database/clblast.py132
-rw-r--r--scripts/database/database/db.py50
-rw-r--r--scripts/database/database/defaults.py58
-rw-r--r--scripts/database/database/io.py58
-rw-r--r--src/clpp11.hpp19
-rw-r--r--src/database/database.cpp55
-rw-r--r--src/database/database.hpp12
-rw-r--r--src/routine.cpp5
-rw-r--r--src/routine.hpp6
-rw-r--r--src/routines/common.cpp40
-rw-r--r--src/routines/common.hpp9
-rw-r--r--src/routines/level3/xgemm.cpp42
16 files changed, 489 insertions, 375 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 77d1cd08..95d1d500 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -95,7 +95,7 @@ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${CFLAGS}")
# ==================================================================================================
# Package scripts location
-set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/Modules/")
+set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${clblast_SOURCE_DIR}/cmake/Modules/")
# Requires OpenCL. It is found through the included "FindOpenCL.cmake" in CMAKE_MODULE_PATH.
find_package(OpenCL REQUIRED)
diff --git a/scripts/database/database.py b/scripts/database/database.py
index a70b9fc1..e115d68c 100644..100755
--- a/scripts/database/database.py
+++ b/scripts/database/database.py
@@ -1,326 +1,112 @@
#!/usr/bin/env python
-# ==================================================================================================
-# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
-# project loosely follows the Google C++ styleguide and uses a max-width of 100 characters per line.
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
#
# Author(s):
# Cedric Nugteren <www.cedricnugteren.nl>
-#
-# ==================================================================================================
-# System modules
import sys
import os.path
import glob
-import re
-import json
-try:
- from urllib.request import urlopen # Python 3
-except ImportError:
- from urllib2 import urlopen # Python 2
+import argparse
-# Additional modules
import pandas as pd
-print("## Using pandas version "+pd.__version__+", requires at least 0.17.0")
+
+import database.io as io
+import database.db as db
+import database.clblast as clblast
+import database.bests as bests
+import database.defaults as defaults
# Server storing a copy of the database
DATABASE_SERVER_URL = "http://www.cedricnugteren.nl/tuning/clblast.db"
-# Constants
-VENDOR_DEFAULT = "default"
-DEVICETYPE_DEFAULT = "All"
-DEVICENAME_DEFAULT = "default"
-
-# Attributes
-DEVICETYPE_ATTRIBUTES = ["device_vendor", "device_type"]
-DEVICE_ATTRIBUTES = ["device", "device_core_clock", "device_compute_units"]
-KERNEL_ATTRIBUTES = ["precision", "kernel_family"]
-ARGUMENT_ATTRIBUTES = ["arg_m", "arg_n", "arg_k", "arg_alpha", "arg_beta"]
-ATTRIBUTES = DEVICE_ATTRIBUTES + DEVICETYPE_ATTRIBUTES + KERNEL_ATTRIBUTES + ARGUMENT_ATTRIBUTES
-
# OpenCL vendor names and their short name
-VENDOR_NAMES = { "device_vendor": {
+VENDOR_TRANSLATION_TABLE = {"device_vendor": {
"GenuineIntel": "Intel",
"Intel(R) Corporation": "Intel",
"Advanced Micro Devices, Inc.": "AMD",
"NVIDIA Corporation": "NVIDIA",
}}
-# Pandas options
-pd.set_option('display.width', 1000)
-
-# ==================================================================================================
-# Database operations
-# ==================================================================================================
-
-# Downloads the database and save it to disk
-def DownloadDatabase(filename):
- print("## Downloading database from '"+DATABASE_SERVER_URL+"'...")
- df = urlopen(DATABASE_SERVER_URL)
- output = open(file_db,'wb')
- output.write(df.read())
- output.close()
-
-# Loads the database from disk
-def LoadDatabase(filename):
- return pd.read_pickle(filename)
-
-# Saves the database to disk
-def SaveDatabase(df, filename):
- df.to_pickle(filename)
-
-# Loads JSON data from file
-def ImportDataFromFile(filename):
- with open(filename) as f:
- data = json.load(f)
- json_data = pd.DataFrame(data)
- df = pd.io.json.json_normalize(json_data["results"])
- for attribute in ATTRIBUTES:
- if attribute == "kernel_family":
- df[attribute] = re.sub(r'_\d+', '', data[attribute])
- elif attribute in data:
- df[attribute] = data[attribute]
- else:
- df[attribute] = 0
- return df
-
-# Returns the row-wise concatenation of two dataframes
-def ConcatenateData(df1, df2):
- return pd.concat([df1, df2])
-
-# Removes duplicates from a dataframe
-def RemoveDuplicates(df):
- return df.drop_duplicates()
-
-# database = database[(database["device"] != "AMD Radeon R9 M370X Compute Engine") | (database["kernel_family"] != "xgemm") | (database["precision"] != "32")]
-def RemoveEntriesByDevice(df, devicename):
- return df[df["device"] != devicename]
-
-def RemoveEntriesByKernelFamily(df, familyname):
- return df[df["kernel_family"] != familyname]
-
-def GetEntriesByField(df, field, value):
- return df[df[field] == value]
-
-# Example usage:
-# df = UpdateDatabase(df, (df["kernel_family"] == "xdot") & (df["arg_n"] == "67108864"), "arg_n", "2097152")
-def UpdateDatabase(df, condition, field, value):
- df.loc[condition, field] = value
- return df
-
-# Fixes the problem that some vendors use multiple different names
-def SanitizeVendorNames(df):
- df = df.replace(VENDOR_NAMES)
- return df
-
-# Retrieves the results with the lowest execution times
-def GetBestResults(df):
- dfbest = pd.DataFrame()
- grouped = df.groupby(ATTRIBUTES+["kernel"])
- for name, dfgroup in grouped:
- besttime = dfgroup["time"].min()
- bestcase = dfgroup[dfgroup["time"] == besttime].iloc[0]
- dfbest = dfbest.append(bestcase, ignore_index=True)
- return dfbest
-
-# Sets defaults for devices of the same type/vendor based on the smallest values of all know
-# entries. The average might be better for performance but some parameters might not be supported
-# on other devices.
-def CalculateDefaults(df):
- dfdefault = pd.DataFrame()
-
- # Defaults per type/vendor
- groups = df.groupby(DEVICETYPE_ATTRIBUTES+KERNEL_ATTRIBUTES+ARGUMENT_ATTRIBUTES+["kernel"])
- for name, dfgroup in groups:
- default_values = dfgroup.min(axis=0)
- default_values["device"] = DEVICENAME_DEFAULT
- default_values["device_compute_units"] = 0
- default_values["device_core_clock"] = 0
- default_values["time"] = 0.0
- dfdefault = dfdefault.append(default_values, ignore_index=True)
-
- # Checks for mis-matched arguments
- groups = dfdefault.groupby(DEVICETYPE_ATTRIBUTES+KERNEL_ATTRIBUTES+["kernel"])
- for name, dfgroup in groups:
- if len(dfgroup) != 1:
- description = dfgroup["kernel"].min() + " " + dfgroup["device_vendor"].min()
- print("[WARNING] Entries for a single kernel with multiple argument values: " + description)
-
- # Defaults in general
- groups = df.groupby(KERNEL_ATTRIBUTES+ARGUMENT_ATTRIBUTES+["kernel"])
- for name, dfgroup in groups:
- default_values = dfgroup.min(axis=0)
- default_values["device_vendor"] = VENDOR_DEFAULT
- default_values["device_type"] = DEVICETYPE_DEFAULT
- default_values["device"] = DEVICENAME_DEFAULT
- default_values["device_compute_units"] = 0
- default_values["device_core_clock"] = 0
- default_values["time"] = 0.0
- dfdefault = dfdefault.append(default_values, ignore_index=True)
-
- # Database with both types of defaults only
- return dfdefault
-
-# ==================================================================================================
-# C++ header generation
-# ==================================================================================================
-
-# The C++ header
-def GetHeader(family):
- return("""
-// =================================================================================================
-// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
-// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
-// width of 100 characters per line.
-//
-// Author(s):
-// Database generator <database.py>
-//
-// This file populates the database with best-found tuning parameters for the '%s' kernels.
-//
-// =================================================================================================
-
-namespace clblast {
-// ================================================================================================="""
- % family.title())
-
-# The C++ footer
-def GetFooter():
- return("\n} // namespace clblast\n")
-
-# The start of a new C++ precision entry
-def GetPrecision(family, precision):
- precisionstring = ""
- if precision == "16":
- precisionstring = "Half"
- elif precision == "32":
- precisionstring = "Single"
- elif precision == "64":
- precisionstring = "Double"
- elif precision == "3232":
- precisionstring = "ComplexSingle"
- elif precision == "6464":
- precisionstring = "ComplexDouble"
- else:
- print("[ERROR] Unknown precision")
- sys.exit()
- return("\n\nconst Database::DatabaseEntry Database::%s%s = {\n \"%s\", Precision::k%s, {\n"
- % (family.title(), precisionstring, family.title(), precisionstring))
-
-# The C++ device type and vendor
-def GetDeviceVendor(vendor, devtype):
- if vendor == VENDOR_DEFAULT and devtype == DEVICETYPE_DEFAULT:
- return(" { // Default\n kDeviceType%s, \"%s\", {\n" % (devtype, vendor))
- return(" { // %s %ss\n kDeviceType%s, \"%s\", {\n" % (vendor, devtype, devtype[0].upper() + devtype[1:], vendor))
-
-# Prints the data to a C++ database
-def PrintData(df, outputdir):
-
- # Iterates over the kernel families: creates a new file per family
- for family, dffamily in df.groupby(["kernel_family"]):
- dffamily = dffamily.dropna(axis=1, how='all')
- f = open(os.path.join(outputdir, family+'.hpp'), 'w+')
- f.write(GetHeader(family))
-
- # Loops over the different entries for this family and prints their headers
- for precision, dfprecision in dffamily.groupby(["precision"]):
- f.write(GetPrecision(family, precision))
- for vendor, dfvendor in dfprecision.groupby(["device_vendor"]):
- for devtype, dfdevtype in dfvendor.groupby(["device_type"]):
- f.write(GetDeviceVendor(vendor, devtype))
- for device, dfdevice in dfdevtype.groupby(["device"]):
- devicename = "\"%s\"," % device
- f.write(" { %-50s { " % devicename)
- # Collects the paramaters for this case and prints them
- parameters = []
- for kernel, dfkernel in dfdevice.groupby(["kernel"]):
- dfkernel = dfkernel.dropna(axis=1)
- col_names = [col for col in list(dfkernel) if col.startswith('parameters.') and col != "parameters.PRECISION"]
- parameters += ["{\"%s\",%d}" % (p.replace("parameters.",""), dfkernel[p].iloc[0]) for p in col_names]
- f.write(", ".join(parameters))
- f.write(" } },\n")
+def main(argv):
- # Prints the footers
- f.write(" }\n },\n")
- f.write(" }\n};\n\n// =================================================================================================")
- f.write(GetFooter())
+ # Parses the command-line arguments
+ parser = argparse.ArgumentParser()
+ parser.add_argument("source_folder", help="The folder with JSON files to parse to add to the database")
+ parser.add_argument("clblast_root", help="Root of the CLBlast sources")
+ parser.add_argument("-v", "--verbose", action="store_true", help="Increase verbosity of the script")
+ cl_args = parser.parse_args(argv)
-# ==================================================================================================
-# Command-line arguments parsing and verification
-# ==================================================================================================
+ # Parses the path arguments
+ database_filename = os.path.join(cl_args.clblast_root, "scripts", "database", "database.db")
+ json_files = os.path.join(cl_args.source_folder, "*.json")
+ cpp_database_path = os.path.join(cl_args.clblast_root, "src", "database", "kernels")
-# Checks for the number of command-line arguments
-if len(sys.argv) != 3:
- print("[ERROR] Usage: database.py <folder_with_json_files> <root_of_clblast>")
- sys.exit()
+ # Checks whether the command-line arguments are valid
+ clblast_header = os.path.join(cl_args.clblast_root, "include", "clblast.h") # Not used but just for validation
+ if not os.path.isfile(clblast_header):
+ raise RuntimeError("The path '" + cl_args.clblast_root + "' does not point to the root of the CLBlast library")
+ if len(glob.glob(json_files)) < 1:
+ print("[database] The path '" + cl_args.source_folder + "' does not contain any JSON files")
-# Parses the command-line arguments
-path_json = sys.argv[1]
-path_clblast = sys.argv[2]
-file_db = os.path.join(path_clblast, "scripts", "database", "database.db")
-glob_json = os.path.join(path_json, "*.json")
+ # Pandas options
+ pd.set_option('display.width', 1000)
+ if cl_args.verbose:
+ print("[database] Using pandas version " + pd.__version__)
-# Checks whether the command-line arguments are valid; exists otherwise
-clblast_h = os.path.join(path_clblast, "include", "clblast.h") # Not used but just for validation
-if not os.path.isfile(clblast_h):
- print("[ERROR] The path '"+path_clblast+"' does not point to the root of the CLBlast library")
- sys.exit()
-if len(glob.glob(glob_json)) < 1:
- print("## The path '"+path_json+"' does not contain any JSON files")
+ # Downloads the database if a local copy is not present
+ if not os.path.isfile(database_filename):
+ io.download_database(database_filename, DATABASE_SERVER_URL)
-# ==================================================================================================
-# The main body of the script
-# ==================================================================================================
+ # Loads the database from disk
+ database = io.load_database(database_filename)
-# Downloads the database if a local copy is not present
-db_exists = os.path.isfile(file_db)
-if not db_exists:
- DownloadDatabase(file_db)
+ # Loops over all JSON files in the supplied folder
+ for file_json in glob.glob(json_files):
-# Loads the database from disk
-print("## Loading the database from disk...")
-database = LoadDatabase(file_db)
+ # Loads the newly imported data
+ sys.stdout.write("[database] Processing '"+file_json+"' ") # No newline printed
+ imported_data = io.load_json_to_pandas(file_json)
-# Loops over all JSON files in the supplied folder
-for file_json in glob.glob(glob_json):
+ # Fixes the problem that some vendors use multiple different names
+ imported_data = db.find_and_replace(imported_data, VENDOR_TRANSLATION_TABLE)
- # Loads the newly imported data
- sys.stdout.write("## Processing '"+file_json+"' ")
- imported_data = ImportDataFromFile(file_json)
- imported_data = SanitizeVendorNames(imported_data)
+ # Adds the new data to the database
+ old_size = len(database.index)
+ database = db.concatenate_database(database, imported_data)
+ database = db.remove_duplicates(database)
+ new_size = len(database.index)
+ print("with " + str(new_size - old_size) + " new items") # Newline printed here
- # Adds the new data to the database
- old_size = len(database.index)
- database = ConcatenateData(database, imported_data)
- database = RemoveDuplicates(database)
- new_size = len(database.index)
- print("with "+str(new_size-old_size)+" new items")
+ # Stores the modified database back to disk
+ if len(glob.glob(json_files)) >= 1:
+ io.save_database(database, database_filename)
-# Stores the modified database back to disk
-if len(glob.glob(glob_json)) >= 1:
- print("## Storing the database to disk...")
- SaveDatabase(database, file_db)
+ # Optional: update the database here. Default is disabled, code below is just an example
+ if False: # TODO: Use command-line arguments to enable updates in a flexible way
+ database = db.update_database(database,
+ ((database["kernel"] == "CopyMatrixFast") &
+ (database["precision"] == "3232")),
+ "arg_alpha", "2+0.5i")
+ io.save_database(database, database_filename)
-# Optional: update the database here. Default is disabled, code below is just an example
-if False:
- database = UpdateDatabase(database, ((database["kernel"] == "CopyMatrixFast") & (database["precision"] == "3232")), "arg_alpha", "2+0.5i")
- SaveDatabase(database, file_db)
+ # Retrieves the best performing results
+ print("[database] Calculating the best results per device/kernel...")
+ database_best_results = bests.get_best_results(database)
-# Retrieves the best performing results
-print("## Calculating the best results per device/kernel...")
-bests = GetBestResults(database)
+ # Determines the defaults for other vendors and per vendor
+ database_defaults = defaults.calculate_defaults(database_best_results)
+ database_best_results = db.concatenate_database(database_best_results, database_defaults)
-# Determines the defaults for other vendors and per vendor
-defaults = CalculateDefaults(bests)
-bests = ConcatenateData(bests, defaults)
+ # Outputs the database as a C++ database
+ print("[database] Producing a C++ database in '" + cpp_database_path + "'...")
+ clblast.print_cpp_database(database_best_results, cpp_database_path)
-# Outputs the data as a C++ database
-path_cpp_database = os.path.join(path_clblast, "src", "database", "kernels")
-print("## Producing a C++ database in '"+path_cpp_database+"'...")
-PrintData(bests, path_cpp_database)
+ print("[database] All done")
-print("## All done")
-# ==================================================================================================
+if __name__ == '__main__':
+ main(sys.argv[1:])
diff --git a/scripts/database/database/__init__.py b/scripts/database/database/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/scripts/database/database/__init__.py
diff --git a/scripts/database/database/bests.py b/scripts/database/database/bests.py
new file mode 100644
index 00000000..edb81733
--- /dev/null
+++ b/scripts/database/database/bests.py
@@ -0,0 +1,20 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import pandas as pd
+import clblast
+
+
+def get_best_results(df):
+ """Retrieves the results with the lowests execution times"""
+ database_bests = pd.DataFrame()
+ database_entries = df.groupby(clblast.ATTRIBUTES + ["kernel"])
+ for name, database_entry in database_entries:
+ best_time = database_entry["time"].min()
+ best_parameters = database_entry[database_entry["time"] == best_time].iloc[0]
+ database_bests = database_bests.append(best_parameters, ignore_index=True)
+ return database_bests
diff --git a/scripts/database/database/clblast.py b/scripts/database/database/clblast.py
new file mode 100644
index 00000000..9c9f7eb4
--- /dev/null
+++ b/scripts/database/database/clblast.py
@@ -0,0 +1,132 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import os
+
+# Constants from the C++ code
+VENDOR_DEFAULT = "default"
+DEVICE_TYPE_DEFAULT = "All"
+DEVICE_NAME_DEFAULT = "default"
+
+# List of attributes
+DEVICE_TYPE_ATTRIBUTES = ["device_vendor", "device_type"]
+DEVICE_ATTRIBUTES = ["device", "device_core_clock", "device_compute_units"]
+KERNEL_ATTRIBUTES = ["precision", "kernel_family"]
+ARGUMENT_ATTRIBUTES = ["arg_m", "arg_n", "arg_k", "arg_alpha", "arg_beta"]
+ATTRIBUTES = DEVICE_ATTRIBUTES + DEVICE_TYPE_ATTRIBUTES + KERNEL_ATTRIBUTES + ARGUMENT_ATTRIBUTES
+
+
+def precision_to_string(precision):
+ """Translates a precision number (represented as Python string) into a descriptive string"""
+ if precision == "16":
+ return "Half"
+ elif precision == "32":
+ return "Single"
+ elif precision == "64":
+ return "Double"
+ elif precision == "3232":
+ return "ComplexSingle"
+ elif precision == "6464":
+ return "ComplexDouble"
+ else:
+ raise("Unknown precision: " + precision)
+
+
+def get_cpp_separator():
+ """Retrieves a C++ comment separator"""
+ return "// ================================================================================================="
+
+
+def get_cpp_header(family):
+ """Retrieves the C++ header"""
+ return ("\n" + get_cpp_separator() + """
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Database generator <database.py>
+//
+// This file populates the database with best-found tuning parameters for the '%s' kernels.
+//\n"""
+ % family.title() + get_cpp_separator() + "\n\nnamespace clblast {\n" + get_cpp_separator())
+
+
+def get_cpp_footer():
+ """Retrieves the C++ footer"""
+ return "\n} // namespace clblast\n"
+
+
+def get_cpp_precision(family, precision):
+ """Retrieves the C++ code for the start of a new precision"""
+ precision_string = precision_to_string(precision)
+ return("\n\nconst Database::DatabaseEntry Database::%s%s = {\n \"%s\", Precision::k%s, {\n"
+ % (family.title(), precision_string, family.title(), precision_string))
+
+
+def get_cpp_device_vendor(vendor, device_type):
+ """Retrieves the C++ code for the (default) vendor and device type"""
+ if vendor == VENDOR_DEFAULT and device_type == DEVICE_TYPE_DEFAULT:
+ return " { // Default\n kDeviceType%s, \"%s\", {\n" % (device_type, vendor)
+ device_type_caps = device_type[0].upper() + device_type[1:]
+ return " { // %s %ss\n kDeviceType%s, \"%s\", {\n" % (vendor, device_type, device_type_caps, vendor)
+
+
+def print_cpp_database(database, output_dir):
+ """Outputs the database as C++ code"""
+
+ # Iterates over the kernel families
+ for family_name, family_database in database.groupby(["kernel_family"]):
+ family_database = family_database.dropna(axis=1, how='all')
+
+ # Opens a new file for each kernel family
+ full_path = os.path.join(output_dir, family_name+'.hpp')
+ with open(full_path, 'w+') as f:
+ f.write(get_cpp_header(family_name))
+
+ # Loops over the different precision (e.g. 16, 32, 3232, 64, 6464)
+ for precision, precision_database in family_database.groupby(["precision"]):
+ f.write(get_cpp_precision(family_name, precision))
+
+ # Loops over a combination of device vendors and device types (e.g. AMD GPU)
+ for vendor, vendor_database in precision_database.groupby(["device_vendor"]):
+ for device_type, device_type_database in vendor_database.groupby(["device_type"]):
+ f.write(get_cpp_device_vendor(vendor, device_type))
+
+ # Loops over every device of this vendor-type combination
+ for device_name, device_database in device_type_database.groupby(["device"]):
+ device_name_quoted = "\"%s\"," % device_name
+ device_name_cpp = " { %-50s { " % device_name_quoted
+ f.write(device_name_cpp)
+
+ # Collects the parameters for this entry
+ parameters = []
+ for kernel, kernel_database in device_database.groupby(["kernel"]):
+ kernel_database = kernel_database.dropna(axis=1)
+
+ # Only consider the actual parameters, not the precision
+ def is_parameter(column):
+ return column.startswith('parameters.') and column != "parameters.PRECISION"
+ column_names = [col for col in list(kernel_database) if is_parameter(col)]
+
+ for p in column_names:
+ parameter_name = p.replace("parameters.", "")
+ parameter_value = int(kernel_database[p].iloc[0])
+ parameters.append("{\"" + parameter_name + "\"," + str(parameter_value) + "}")
+
+ # Prints the entry
+ f.write(", ".join(parameters))
+ f.write(" } },\n")
+
+ # Prints the vendor-type combination footer
+ f.write(" }\n },\n")
+
+ # Prints the precision footer
+ f.write(" }\n};\n\n" + get_cpp_separator())
+
+ # Prints the file footer
+ f.write(get_cpp_footer())
diff --git a/scripts/database/database/db.py b/scripts/database/database/db.py
new file mode 100644
index 00000000..60cfbcfa
--- /dev/null
+++ b/scripts/database/database/db.py
@@ -0,0 +1,50 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import pandas as pd
+
+
+def get_entries_by_field(database, field, value):
+ """Retrieves entries from the database with a specific value for a given field"""
+ return database[database[field] == value]
+
+
+def concatenate_database(database1, database2):
+ """Concatenates two databases row-wise and returns the result"""
+ return pd.concat([database1, database2])
+
+
+def remove_duplicates(database):
+ """Removes duplicates from a database"""
+ return database.drop_duplicates()
+
+
+def find_and_replace(database, dictionary):
+ """Finds and replaces entries in a database based on a dictionary. Example:
+ dictionary = { "key_to_edit": { find1: replace1, find2, replace2 } }"""
+ return database.replace(dictionary)
+
+
+def remove_entries_by_key_value(database, key, value):
+ """Removes entries in the databased which have a specific value for a given key"""
+ return database[database[key] != value]
+
+
+def remove_entries_by_device(database, device_name):
+ """Shorthand for the above, specifically removes entries for a given device"""
+ return remove_entries_by_key_value(database, "device", device_name)
+
+
+def remove_entries_by_kernel_family(database, kernel_family_name):
+ """Shorthand for the above, specifically removes entries for a given kernel family"""
+ return remove_entries_by_key_value(database, "kernel_family", kernel_family_name)
+
+
+def update_database(database, condition, field, value):
+ """Updates the database by writing a specific value to a given field, given certain conditions"""
+ database.loc[condition, field] = value
+ return database
diff --git a/scripts/database/database/defaults.py b/scripts/database/database/defaults.py
new file mode 100644
index 00000000..357c3a3a
--- /dev/null
+++ b/scripts/database/database/defaults.py
@@ -0,0 +1,58 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import pandas as pd
+import clblast
+
+
+def set_default_device(database_entry):
+ """Sets the device name and parameters to some default values"""
+ database_entry["device"] = clblast.DEVICE_NAME_DEFAULT
+ database_entry["device_compute_units"] = 0
+ database_entry["device_core_clock"] = 0
+ return database_entry
+
+
+def set_default_time(database_entry):
+ """Sets the execution time to some default value"""
+ database_entry["time"] = 0.0
+ return database_entry
+
+
+def calculate_defaults(df):
+ """# Sets defaults for devices of the same type/vendor based on the smallest values of all known entries. The average
+ might be better for performance but some parameters might not be supported on other devices."""
+ database_defaults = pd.DataFrame()
+
+ # Defaults per combination of device vendors and device types (e.g. AMD GPU)
+ database_type_vendor = df.groupby(clblast.DEVICE_TYPE_ATTRIBUTES + clblast.KERNEL_ATTRIBUTES + ["kernel"] +
+ clblast.ARGUMENT_ATTRIBUTES)
+ for group_name, database_group in database_type_vendor:
+ default_values = database_group.min(axis=0)
+ default_values = set_default_device(default_values)
+ default_values = set_default_time(default_values)
+ database_defaults = database_defaults.append(default_values, ignore_index=True)
+
+ # Checks for mis-matched arguments
+ groups = database_defaults.groupby(clblast.DEVICE_TYPE_ATTRIBUTES + clblast.KERNEL_ATTRIBUTES + ["kernel"])
+ for group_name, database_group in groups:
+ if len(database_group) != 1:
+ description = database_group["kernel"].min() + " " + database_group["device_vendor"].min()
+ print("[WARNING] Entries for a single kernel with multiple argument values: " + description)
+
+ # Defaults over all device types and vendors
+ groups = df.groupby(clblast.KERNEL_ATTRIBUTES + ["kernel"] + clblast.ARGUMENT_ATTRIBUTES)
+ for group_name, database_group in groups:
+ default_values = database_group.min(axis=0)
+ default_values["device_vendor"] = clblast.VENDOR_DEFAULT
+ default_values["device_type"] = clblast.DEVICE_TYPE_DEFAULT
+ default_values = set_default_device(default_values)
+ default_values = set_default_time(default_values)
+ database_defaults = database_defaults.append(default_values, ignore_index=True)
+
+ # Database with both types of defaults only
+ return database_defaults
diff --git a/scripts/database/database/io.py b/scripts/database/database/io.py
new file mode 100644
index 00000000..ad2f7ae9
--- /dev/null
+++ b/scripts/database/database/io.py
@@ -0,0 +1,58 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import re
+import json
+
+try:
+ from urllib.request import urlopen # Python 3
+except ImportError:
+ from urllib2 import urlopen # Python 2
+
+import pandas as pd
+
+import clblast
+
+
+def download_database(filename, database_url):
+ """Downloads a database and saves it to disk"""
+ print("[database] Downloading database from '" + database_url + "'...")
+ database = urlopen(database_url)
+ with open(filename, 'wb') as f:
+ f.write(database.read())
+
+
+def load_database(filename):
+ """Loads a database from disk"""
+ print("[database] Loading database from '" + filename + "'")
+ return pd.read_pickle(filename)
+
+
+def save_database(database, filename):
+ """Saves a database to disk"""
+ print("[database] Saving database to '" + filename + "'")
+ database.to_pickle(filename)
+
+
+def load_json_to_pandas(filename):
+ """Loads JSON data from file and converts it to a pandas database"""
+ with open(filename) as f:
+ json_data = json.load(f)
+
+ # Gathers all results and stores them in a new database
+ json_database = pd.DataFrame(json_data)
+ new_database = pd.io.json.json_normalize(json_database["results"])
+
+ # Sets the common attributes to each entry in the results
+ for attribute in clblast.ATTRIBUTES:
+ if attribute == "kernel_family":
+ new_database[attribute] = re.sub(r'_\d+', '', json_data[attribute])
+ elif attribute in json_data:
+ new_database[attribute] = json_data[attribute]
+ else:
+ new_database[attribute] = 0 # For example a parameters that was not used by this kernel
+ return new_database
diff --git a/src/clpp11.hpp b/src/clpp11.hpp
index af9d2ea4..d57223dd 100644
--- a/src/clpp11.hpp
+++ b/src/clpp11.hpp
@@ -109,7 +109,9 @@ class Event {
// Accessor to the private data-member
cl_event& operator()() { return *event_; }
+ const cl_event& operator()() const { return *event_; }
cl_event* pointer() { return &(*event_); }
+ const cl_event* pointer() const { return &(*event_); }
private:
std::shared_ptr<cl_event> event_;
};
@@ -686,30 +688,21 @@ class Kernel {
// As above, but with an event waiting list
void Launch(const Queue &queue, const std::vector<size_t> &global,
const std::vector<size_t> &local, EventPointer event,
- std::vector<Event>& waitForEvents) {
- if (waitForEvents.size() == 0) { return Launch(queue, global, local, event); }
-
+ const std::vector<Event> &waitForEvents) {
// Builds a plain version of the events waiting list
auto waitForEventsPlain = std::vector<cl_event>();
for (auto &waitEvent : waitForEvents) {
- waitForEventsPlain.push_back(waitEvent());
+ if (waitEvent()) { waitForEventsPlain.push_back(waitEvent()); }
}
// Launches the kernel while waiting for other events
CheckError(clEnqueueNDRangeKernel(queue(), *kernel_, static_cast<cl_uint>(global.size()),
- nullptr, global.data(), local.data(),
+ nullptr, global.data(), !local.empty() ? local.data() : nullptr,
static_cast<cl_uint>(waitForEventsPlain.size()),
- waitForEventsPlain.data(),
+ !waitForEventsPlain.empty() ? waitForEventsPlain.data() : nullptr,
event));
}
- // As above, but with the default local workgroup size
- void Launch(const Queue &queue, const std::vector<size_t> &global, EventPointer event) {
- CheckError(clEnqueueNDRangeKernel(queue(), *kernel_, static_cast<cl_uint>(global.size()),
- nullptr, global.data(), nullptr,
- 0, nullptr, event));
- }
-
// Accessor to the private data-member
const cl_kernel& operator()() const { return *kernel_; }
private:
diff --git a/src/database/database.cpp b/src/database/database.cpp
index 6ec93731..47f1da16 100644
--- a/src/database/database.cpp
+++ b/src/database/database.cpp
@@ -42,9 +42,10 @@ const std::vector<Database::DatabaseEntry> Database::database = {
// =================================================================================================
-// Constructor, computing device properties and populating the parameter-vector from the database
+// Constructor, computing device properties and populating the parameter-vector from the database.
+// This takes an optional overlay database in case of custom tuning or custom kernels.
Database::Database(const Queue &queue, const std::vector<std::string> &kernels,
- const Precision precision):
+ const Precision precision, const std::vector<DatabaseEntry> &overlay):
parameters_{} {
// Finds information of the current device
@@ -53,10 +54,26 @@ Database::Database(const Queue &queue, const std::vector<std::string> &kernels,
auto device_vendor = device.Vendor();
auto device_name = device.Name();
+ // Set the short vendor name
+ for (auto &combination : kVendorNames) {
+ if (device_vendor == combination.first) {
+ device_vendor = combination.second;
+ }
+ }
+
// Iterates over all kernels to include, and retrieves the parameters for each of them
for (auto &kernel: kernels) {
- auto search_result = Search(kernel, device_type, device_vendor, device_name, precision);
- parameters_.insert(search_result.begin(), search_result.end());
+ auto search_result = ParametersPtr{};
+
+ for (auto db: { &overlay, &database }) {
+ search_result = Search(kernel, device_type, device_vendor, device_name, precision, *db);
+ if (search_result) {
+ parameters_.insert(search_result->begin(), search_result->end());
+ break;
+ }
+ }
+
+ if (!search_result) { throw std::runtime_error("Database error, could not find a suitable entry"); }
}
}
@@ -73,28 +90,22 @@ std::string Database::GetDefines() const {
// =================================================================================================
-// Searches the database for the right kernel and precision
-Database::Parameters Database::Search(const std::string &this_kernel,
- const std::string &this_type,
- const std::string &this_vendor,
- const std::string &this_device,
- const Precision this_precision) const {
- // Set the short vendor name
- auto this_short_vendor = this_vendor;
- for (auto &combination : kVendorNames) {
- if (this_vendor == combination.first) {
- this_short_vendor = combination.second;
- }
- }
+// Searches a particular database for the right kernel and precision
+Database::ParametersPtr Database::Search(const std::string &this_kernel,
+ const std::string &this_type,
+ const std::string &this_vendor,
+ const std::string &this_device,
+ const Precision this_precision,
+ const std::vector<DatabaseEntry> &this_database) const {
// Selects the right kernel
- for (auto &db: database) {
+ for (auto &db: this_database) {
if (db.kernel == this_kernel && db.precision == this_precision) {
// Searches for the right vendor and device type, or selects the default if unavailable. This
// assumes that the default vendor / device type is last in the database.
for (auto &vendor: db.vendors) {
- if ((vendor.name == this_short_vendor || vendor.name == kDeviceVendorAll) &&
+ if ((vendor.name == this_vendor || vendor.name == kDeviceVendorAll) &&
(vendor.type == this_type || vendor.type == kDeviceTypeAll)) {
// Searches for the right device. If the current device is unavailable, selects the vendor
@@ -104,7 +115,7 @@ Database::Parameters Database::Search(const std::string &this_kernel,
if (device.name == this_device || device.name == "default") {
// Sets the parameters accordingly
- return device.parameters;
+ return &device.parameters;
}
}
}
@@ -112,8 +123,8 @@ Database::Parameters Database::Search(const std::string &this_kernel,
}
}
- // If we reached this point, something is wrong
- throw std::runtime_error("Database error, could not find a suitable entry");
+ // If we reached this point, the entry was not found in this database
+ return nullptr;
}
// =================================================================================================
diff --git a/src/database/database.hpp b/src/database/database.hpp
index 0987cbed..e84357dc 100644
--- a/src/database/database.hpp
+++ b/src/database/database.hpp
@@ -32,6 +32,7 @@ class Database {
// Type alias for the database parameters
using Parameters = std::unordered_map<std::string,size_t>;
+ using ParametersPtr = const Parameters*;
// Structures for content inside the database
struct DatabaseDevice {
@@ -78,9 +79,9 @@ class Database {
static const DatabaseEntry PadtransposeHalf, PadtransposeSingle, PadtransposeDouble, PadtransposeComplexSingle, PadtransposeComplexDouble;
static const std::vector<DatabaseEntry> database;
- // The constructor
+ // The constructor with a user-provided database overlay (potentially an empty vector)
explicit Database(const Queue &queue, const std::vector<std::string> &routines,
- const Precision precision);
+ const Precision precision, const std::vector<DatabaseEntry> &overlay);
// Accessor of values by key
size_t operator[](const std::string key) const { return parameters_.find(key)->second; }
@@ -89,9 +90,10 @@ class Database {
std::string GetDefines() const;
private:
- Parameters Search(const std::string &this_kernel, const std::string &this_type,
- const std::string &this_vendor, const std::string &this_device,
- const Precision this_precision) const;
+ // Search method for a specified database, returning pointer (possibly a nullptr)
+ ParametersPtr Search(const std::string &this_kernel, const std::string &this_type,
+ const std::string &this_vendor, const std::string &this_device,
+ const Precision this_precision, const std::vector<DatabaseEntry> &db) const;
// Found parameters suitable for this device/kernel
Parameters parameters_;
diff --git a/src/routine.cpp b/src/routine.cpp
index 3c3343da..189ae190 100644
--- a/src/routine.cpp
+++ b/src/routine.cpp
@@ -22,7 +22,8 @@ namespace clblast {
// Constructor: not much here, because no status codes can be returned
Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
- const std::vector<std::string> &routines, const Precision precision):
+ const std::vector<std::string> &routines, const Precision precision,
+ const std::vector<Database::DatabaseEntry> &userDatabase):
precision_(precision),
routine_name_(name),
queue_(queue),
@@ -30,7 +31,7 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
context_(queue_.GetContext()),
device_(queue_.GetDevice()),
device_name_(device_.Name()),
- db_(queue_, routines, precision_) {
+ db_(queue_, routines, precision_, userDatabase) {
}
// =================================================================================================
diff --git a/src/routine.hpp b/src/routine.hpp
index 54b5779f..f5c607af 100644
--- a/src/routine.hpp
+++ b/src/routine.hpp
@@ -32,9 +32,11 @@ namespace clblast {
class Routine {
public:
- // Base class constructor
+ // Base class constructor. The user database is an optional extra database to override the
+ // built-in database.
explicit Routine(Queue &queue, EventPointer event, const std::string &name,
- const std::vector<std::string> &routines, const Precision precision);
+ const std::vector<std::string> &routines, const Precision precision,
+ const std::vector<Database::DatabaseEntry> &userDatabase = {});
// Set-up phase of the kernel
StatusCode SetUp();
diff --git a/src/routines/common.cpp b/src/routines/common.cpp
index 2e82e04d..3969cf9f 100644
--- a/src/routines/common.cpp
+++ b/src/routines/common.cpp
@@ -22,23 +22,25 @@ namespace clblast {
// Enqueues a kernel, waits for completion, and checks for errors
StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device,
std::vector<size_t> global, const std::vector<size_t> &local,
- EventPointer event, std::vector<Event>& waitForEvents) {
+ EventPointer event, const std::vector<Event> &waitForEvents) {
- // Tests for validity of the local thread sizes
- if (local.size() > device.MaxWorkItemDimensions()) {
- return StatusCode::kInvalidLocalNumDimensions;
- }
- const auto max_work_item_sizes = device.MaxWorkItemSizes();
- for (auto i=size_t{0}; i<local.size(); ++i) {
- if (local[i] > max_work_item_sizes[i]) { return StatusCode::kInvalidLocalThreadsDim; }
- }
- auto local_size = size_t{1};
- for (auto &item: local) { local_size *= item; }
- if (local_size > device.MaxWorkGroupSize()) { return StatusCode::kInvalidLocalThreadsTotal; }
+ if (!local.empty()) {
+ // Tests for validity of the local thread sizes
+ if (local.size() > device.MaxWorkItemDimensions()) {
+ return StatusCode::kInvalidLocalNumDimensions;
+ }
+ const auto max_work_item_sizes = device.MaxWorkItemSizes();
+ for (auto i=size_t{0}; i<local.size(); ++i) {
+ if (local[i] > max_work_item_sizes[i]) { return StatusCode::kInvalidLocalThreadsDim; }
+ }
+ auto local_size = size_t{1};
+ for (auto &item: local) { local_size *= item; }
+ if (local_size > device.MaxWorkGroupSize()) { return StatusCode::kInvalidLocalThreadsTotal; }
- // Make sure the global thread sizes are at least equal to the local sizes
- for (auto i=size_t{0}; i<global.size(); ++i) {
- if (global[i] < local[i]) { global[i] = local[i]; }
+ // Make sure the global thread sizes are at least equal to the local sizes
+ for (auto i=size_t{0}; i<global.size(); ++i) {
+ if (global[i] < local[i]) { global[i] = local[i]; }
+ }
}
// Tests for local memory usage
@@ -69,13 +71,5 @@ StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device,
return StatusCode::kSuccess;
}
-// As above, but without an event waiting list
-StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device,
- std::vector<size_t> global, const std::vector<size_t> &local,
- EventPointer event) {
- auto emptyWaitingList = std::vector<Event>();
- return RunKernel(kernel, queue, device, global, local, event, emptyWaitingList);
-}
-
// =================================================================================================
} // namespace clblast
diff --git a/src/routines/common.hpp b/src/routines/common.hpp
index d53bdc25..9d8849c3 100644
--- a/src/routines/common.hpp
+++ b/src/routines/common.hpp
@@ -29,12 +29,7 @@ namespace clblast {
// Enqueues a kernel, waits for completion, and checks for errors
StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device,
std::vector<size_t> global, const std::vector<size_t> &local,
- EventPointer event, std::vector<Event>& waitForEvents);
-
-// As above, but without an event waiting list
-StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device,
- std::vector<size_t> global, const std::vector<size_t> &local,
- EventPointer event);
+ EventPointer event, const std::vector<Event> &waitForEvents = {});
// =================================================================================================
@@ -43,7 +38,7 @@ StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device,
template <typename T>
StatusCode PadCopyTransposeMatrix(Queue &queue, const Device &device,
const Database &db,
- EventPointer event, std::vector<Event>& waitForEvents,
+ EventPointer event, const std::vector<Event> &waitForEvents,
const size_t src_one, const size_t src_two,
const size_t src_ld, const size_t src_offset,
const Buffer<T> &src,
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index 0db28537..fce59622 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -63,9 +63,12 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout,
const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) ||
(layout == Layout::kRowMajor && b_transpose == Transpose::kNo);
const auto c_rotated = (layout == Layout::kRowMajor);
- const auto a_do_transpose = a_rotated;
- const auto b_do_transpose = !b_rotated;
- const auto c_do_transpose = c_rotated;
+ static const auto a_want_rotated = false;
+ static const auto b_want_rotated = true;
+ static const auto c_want_rotated = false;
+ const auto a_do_transpose = a_rotated != a_want_rotated;
+ const auto b_do_transpose = b_rotated != b_want_rotated;
+ const auto c_do_transpose = c_rotated != c_want_rotated;
// In case of complex data-types, the transpose can also become a conjugate transpose
const auto a_conjugate = (a_transpose == Transpose::kConjugate);
@@ -99,6 +102,15 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout,
const auto n_ceiled = Ceil(n, db_["NWG"]);
const auto k_ceiled = Ceil(k, db_["KWG"]);
+ // Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account
+ // whether the matrices need to be rotated or not for the kernel.
+ const auto a_one_i = (a_want_rotated) ? k_ceiled : m_ceiled;
+ const auto a_two_i = (a_want_rotated) ? m_ceiled : k_ceiled;
+ const auto b_one_i = (b_want_rotated) ? n_ceiled : k_ceiled;
+ const auto b_two_i = (b_want_rotated) ? k_ceiled : n_ceiled;
+ const auto c_one_i = (c_want_rotated) ? n_ceiled : m_ceiled;
+ const auto c_two_i = (c_want_rotated) ? m_ceiled : n_ceiled;
+
// The padded/transposed input/output matrices: if memory allocation fails, throw an exception
try {
@@ -106,17 +118,17 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout,
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
// Determines whether or not temporary matrices are needed
- auto a_no_temp = a_one == m_ceiled && a_two == k_ceiled && a_ld == m_ceiled && a_offset == 0 &&
+ auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && a_offset == 0 &&
a_do_transpose == false && a_conjugate == false;
- auto b_no_temp = b_one == n_ceiled && b_two == k_ceiled && b_ld == n_ceiled && b_offset == 0 &&
+ auto b_no_temp = b_one == b_one_i && b_two == b_two_i && b_ld == b_one && b_offset == 0 &&
b_do_transpose == false && b_conjugate == false;
- auto c_no_temp = c_one == m_ceiled && c_two == n_ceiled && c_ld == m_ceiled && c_offset == 0 &&
+ auto c_no_temp = c_one == c_one_i && c_two == c_two_i && c_ld == c_one && c_offset == 0 &&
c_do_transpose == false;
// Creates the temporary matrices
- const auto a_temp = (a_no_temp) ? a_buffer : Buffer<T>(context_, k_ceiled*m_ceiled);
- const auto b_temp = (b_no_temp) ? b_buffer : Buffer<T>(context_, k_ceiled*n_ceiled);
- const auto c_temp = (c_no_temp) ? c_buffer : Buffer<T>(context_, m_ceiled*n_ceiled);
+ const auto a_temp = (a_no_temp) ? a_buffer : Buffer<T>(context_, a_one_i*a_two_i);
+ const auto b_temp = (b_no_temp) ? b_buffer : Buffer<T>(context_, b_one_i*b_two_i);
+ const auto c_temp = (c_no_temp) ? c_buffer : Buffer<T>(context_, c_one_i*c_two_i);
// Events of all kernels (including pre/post processing kernels)
auto eventWaitList = std::vector<Event>();
@@ -129,7 +141,7 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout,
auto eventProcessA = Event();
status = PadCopyTransposeMatrix(queue_, device_, db_, eventProcessA.pointer(), emptyEventList,
a_one, a_two, a_ld, a_offset, a_buffer,
- m_ceiled, k_ceiled, m_ceiled, 0, a_temp,
+ a_one_i, a_two_i, a_one_i, 0, a_temp,
ConstantOne<T>(), program,
true, a_do_transpose, a_conjugate);
if (ErrorIn(status)) { return status; }
@@ -141,7 +153,7 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout,
auto eventProcessB = Event();
status = PadCopyTransposeMatrix(queue_, device_, db_, eventProcessB.pointer(), emptyEventList,
b_one, b_two, b_ld, b_offset, b_buffer,
- n_ceiled, k_ceiled, n_ceiled, 0, b_temp,
+ b_one_i, b_two_i, b_one_i, 0, b_temp,
ConstantOne<T>(), program,
true, b_do_transpose, b_conjugate);
if (ErrorIn(status)) { return status; }
@@ -153,7 +165,7 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout,
auto eventProcessC = Event();
status = PadCopyTransposeMatrix(queue_, device_, db_, eventProcessC.pointer(), emptyEventList,
c_one, c_two, c_ld, c_offset, c_buffer,
- m_ceiled, n_ceiled, m_ceiled, 0, c_temp,
+ c_one_i, c_two_i, c_one_i, 0, c_temp,
ConstantOne<T>(), program,
true, c_do_transpose, false);
if (ErrorIn(status)) { return status; }
@@ -176,8 +188,8 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout,
// Computes the global and local thread sizes
const auto global = std::vector<size_t>{
- (m_ceiled * db_["MDIMC"]) / db_["MWG"],
- (n_ceiled * db_["NDIMC"]) / db_["NWG"]
+ (c_one_i * db_["MDIMC"]) / db_["MWG"],
+ (c_two_i * db_["NDIMC"]) / db_["NWG"]
};
const auto local = std::vector<size_t>{db_["MDIMC"], db_["NDIMC"]};
@@ -191,7 +203,7 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout,
if (!c_no_temp) {
eventWaitList.push_back(eventKernel);
status = PadCopyTransposeMatrix(queue_, device_, db_, event_, eventWaitList,
- m_ceiled, n_ceiled, m_ceiled, 0, c_temp,
+ c_one_i, c_two_i, c_one_i, 0, c_temp,
c_one, c_two, c_ld, c_offset, c_buffer,
ConstantOne<T>(), program,
false, c_do_transpose, false);