summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-01-11 23:09:48 +0100
committerGitHub <noreply@github.com>2018-01-11 23:09:48 +0100
commit6d52eb2956c294a46b43627c468813cf64f60b99 (patch)
treecaa7d889091d004d066772083a157a38f9e6d967
parent9b084d04093fdbfb22ee4790c6b3db5c55cd2719 (diff)
parent90e8e55acb4f4c059317fb7e812389a9704b2cb3 (diff)
Merge pull request #240 from CNugteren/retrieve_tuning_parameters
Retrieve tuning parameters
-rw-r--r--CHANGELOG3
-rw-r--r--CMakeLists.txt2
-rw-r--r--include/clblast.h5
-rw-r--r--include/clblast_cuda.h5
-rwxr-xr-xscripts/generator/generator.py2
-rw-r--r--src/api_common.cpp28
-rw-r--r--src/database/database.hpp1
-rw-r--r--test/correctness/misc/override_parameters.cpp3
-rw-r--r--test/correctness/misc/retrieve_parameters.cpp84
9 files changed, 129 insertions, 4 deletions
diff --git a/CHANGELOG b/CHANGELOG
index e4205894..83ba7b07 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -10,6 +10,7 @@ Development (next version)
- Improved compilation time by splitting the tuning database into multiple compilation units
- Various minor fixes and enhancements
- Added tuned parameters for various devices (see README)
+- Added the RetrieveParameters function to the API to be able to inspect the tuning parameters
- Added a strided-batched (not part of the BLAS standard) routine, faster but less generic compared
to the existing xGEMMBATCHED routines:
* SGEMMSTRIDEDBATCHED/DGEMMSTRIDEDBATCHED/CGEMMSTRIDEDBATCHED/ZGEMMSTRIDEDBATCHED/HGEMMSTRIDEDBATCHED
@@ -70,7 +71,7 @@ Version 0.11.0
- Replaced the R graph scripts with Python/Matplotlib scripts
- Various minor fixes and enhancements
- Added tuned parameters for various devices (see README)
-- Added the OverrideParameters function to the API to be able to supply custom tuning parmeters
+- Added the OverrideParameters function to the API to be able to supply custom tuning parameters
- Added triangular solver (level-2 & level-3) routines:
* STRSV/DTRSV/CTRSV/ZTRSV (experimental, un-optimized)
* STRSM/DTRSM/CTRSM/ZTRSM (experimental, un-optimized)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 64f258c5..18254658 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -578,7 +578,7 @@ if(TESTS)
endforeach()
# Miscellaneous tests
- set(MISC_TESTS override_parameters)
+ set(MISC_TESTS override_parameters retrieve_parameters)
if(NOT CUDA)
set(MISC_TESTS ${MISC_TESTS} preprocessor)
endif()
diff --git a/include/clblast.h b/include/clblast.h
index 8e3e64da..c4ff5290 100644
--- a/include/clblast.h
+++ b/include/clblast.h
@@ -682,6 +682,11 @@ StatusCode PUBLIC_API FillCache(const cl_device_id device);
// =================================================================================================
+// Retrieves current tuning parameters for a specific device-precision-kernel combination
+StatusCode PUBLIC_API RetrieveParameters(const cl_device_id device, const std::string &kernel_name,
+ const Precision precision,
+ std::unordered_map<std::string,size_t> &parameters);
+
// Overrides tuning parameters for a specific device-precision-kernel combination. The next time
// the target routine is called it will re-compile and use the new parameters from then on.
StatusCode PUBLIC_API OverrideParameters(const cl_device_id device, const std::string &kernel_name,
diff --git a/include/clblast_cuda.h b/include/clblast_cuda.h
index b0cb9aa8..ed348efe 100644
--- a/include/clblast_cuda.h
+++ b/include/clblast_cuda.h
@@ -654,6 +654,11 @@ StatusCode PUBLIC_API FillCache(const CUdevice device);
// =================================================================================================
+// Retrieves current tuning parameters for a specific device-precision-kernel combination
+StatusCode PUBLIC_API RetrieveParameters(const CUdevice device, const std::string &kernel_name,
+ const Precision precision,
+ std::unordered_map<std::string,size_t> &parameters);
+
// Overrides tuning parameters for a specific device-precision-kernel combination. The next time
// the target routine is called it will re-compile and use the new parameters from then on.
StatusCode PUBLIC_API OverrideParameters(const CUdevice device, const std::string &kernel_name,
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 528e61dd..b77b861e 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -47,7 +47,7 @@ FILES = [
"/src/clblast_cuda.cpp",
]
HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 95, 21]
-FOOTER_LINES = [36, 56, 27, 38, 6, 6, 6, 9, 2, 36, 55]
+FOOTER_LINES = [41, 56, 27, 38, 6, 6, 6, 9, 2, 41, 55]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 63
diff --git a/src/api_common.cpp b/src/api_common.cpp
index 0d387cd9..4e08f1ef 100644
--- a/src/api_common.cpp
+++ b/src/api_common.cpp
@@ -112,6 +112,34 @@ StatusCode FillCache(const RawDeviceID device) {
// =================================================================================================
+// Retrieves the current tuning parameters for this device-precision-kernel combination
+StatusCode RetrieveParameters(const RawDeviceID device, const std::string &kernel_name,
+ const Precision precision,
+ std::unordered_map<std::string,size_t> &parameters) {
+ try {
+
+ // Retrieves the device name
+ const auto device_cpp = Device(device);
+ const auto platform_id = device_cpp.PlatformID();
+ const auto device_name = GetDeviceName(device_cpp);
+
+ // Retrieves the database values
+ auto in_cache = false;
+ auto database = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device, precision, kernel_name}, &in_cache);
+ if (!in_cache) {
+ log_debug("Searching database for kernel '" + kernel_name + "'");
+ database = Database(device_cpp, kernel_name, precision, {});
+ }
+
+ // Retrieves the parameters
+ for (const auto &parameter: database.GetParameters()) {
+ parameters[parameter.first] = parameter.second;
+ }
+
+ } catch (...) { return DispatchException(); }
+ return StatusCode::kSuccess;
+}
+
// Overrides the tuning parameters for this device-precision-kernel combination
StatusCode OverrideParameters(const RawDeviceID device, const std::string &kernel_name,
const Precision precision,
diff --git a/src/database/database.hpp b/src/database/database.hpp
index 8e53e013..1db2c286 100644
--- a/src/database/database.hpp
+++ b/src/database/database.hpp
@@ -56,6 +56,7 @@ class Database {
// Retrieves the values or names of all the parameters
std::string GetValuesString() const;
std::vector<std::string> GetParameterNames() const;
+ const database::Parameters& GetParameters() const { return *parameters_; }
private:
// Search method functions, returning a set of parameters (possibly empty)
diff --git a/test/correctness/misc/override_parameters.cpp b/test/correctness/misc/override_parameters.cpp
index 05f40f57..1edfb2ba 100644
--- a/test/correctness/misc/override_parameters.cpp
+++ b/test/correctness/misc/override_parameters.cpp
@@ -86,11 +86,12 @@ size_t RunOverrideTests(int argc, char *argv[], const bool silent, const std::st
auto device_a = Buffer<T>(context, host_a.size());
auto device_b = Buffer<T>(context, host_b.size());
auto device_c = Buffer<T>(context, host_c.size());
+ auto device_temp = Buffer<T>(context, args.m * args.n * args.k); // just to be safe
device_a.Write(queue, host_a.size(), host_a);
device_b.Write(queue, host_b.size(), host_b);
device_c.Write(queue, host_c.size(), host_c);
auto dummy = Buffer<T>(context, 1);
- auto buffers = Buffers<T>{dummy, dummy, device_a, device_b, device_c, dummy, dummy};
+ auto buffers = Buffers<T>{dummy, dummy, device_a, device_b, device_c, device_temp, dummy};
// Loops over the valid combinations: run before and run afterwards
fprintf(stdout, "* Testing OverrideParameters for '%s'\n", routine_name.c_str());
diff --git a/test/correctness/misc/retrieve_parameters.cpp b/test/correctness/misc/retrieve_parameters.cpp
new file mode 100644
index 00000000..568dab0d
--- /dev/null
+++ b/test/correctness/misc/retrieve_parameters.cpp
@@ -0,0 +1,84 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file contains the tests for the RetrieveParameters function
+//
+// =================================================================================================
+
+#include <string>
+#include <vector>
+#include <unordered_map>
+#include <iostream>
+
+#include "utilities/utilities.hpp"
+
+namespace clblast {
+// =================================================================================================
+
+template <typename T>
+size_t RunRetrieveParametersTests(int argc, char *argv[], const bool silent, const std::string &routine_name) {
+ auto arguments = RetrieveCommandLineArguments(argc, argv);
+ auto errors = size_t{0};
+ auto passed = size_t{0};
+
+ // Retrieves the arguments
+ auto help = std::string{"Options given/available:\n"};
+ const auto platform_id = GetArgument(arguments, help, kArgPlatform, ConvertArgument(std::getenv("CLBLAST_PLATFORM"), size_t{0}));
+ const auto device_id = GetArgument(arguments, help, kArgDevice, ConvertArgument(std::getenv("CLBLAST_DEVICE"), size_t{0}));
+ auto args = Arguments<T>{};
+
+ // Determines the test settings
+ const auto kernel_name = std::string{"Xgemm"};
+ const auto expected_parameters = std::vector<std::string>{
+ "KWG", "KWI", "MDIMA", "MDIMC", "MWG", "NDIMB", "NDIMC", "NWG", "SA", "SB", "STRM", "STRN", "VWM", "VWN"
+ };
+ const auto expected_max_value = size_t{16384};
+
+ // Prints the help message (command-line arguments)
+ if (!silent) { fprintf(stdout, "\n* %s\n", help.c_str()); }
+
+ // Initializes OpenCL
+ const auto platform = Platform(platform_id);
+ const auto device = Device(platform, device_id);
+
+ // Retrieves the parameters
+ fprintf(stdout, "* Testing RetrieveParameters for '%s'\n", routine_name.c_str());
+ auto parameters = std::unordered_map<std::string,size_t>();
+ const auto status = RetrieveParameters(device(), kernel_name, PrecisionValue<T>(), parameters);
+ if (status != StatusCode::kSuccess) { errors++; }
+
+ // Verifies the parameters
+ for (const auto &expected_parameter : expected_parameters) {
+ if (parameters.find(expected_parameter) != parameters.end()) {
+ const auto value = parameters[expected_parameter];
+ if (value < expected_max_value) { passed++; } else { errors++; }
+ //std::cout << expected_parameter << " = " << value << std::endl;
+ }
+ else { errors++; }
+ }
+
+ // Prints and returns the statistics
+ std::cout << " " << passed << " test(s) passed" << std::endl;
+ std::cout << " " << errors << " test(s) failed" << std::endl;
+ std::cout << std::endl;
+ return errors;
+}
+
+// =================================================================================================
+} // namespace clblast
+
+// Main function (not within the clblast namespace)
+int main(int argc, char *argv[]) {
+ auto errors = size_t{0};
+ errors += clblast::RunRetrieveParametersTests<float>(argc, argv, false, "SGEMM");
+ errors += clblast::RunRetrieveParametersTests<clblast::float2>(argc, argv, true, "CGEMM");
+ if (errors > 0) { return 1; } else { return 0; }
+}
+
+// =================================================================================================