summaryrefslogtreecommitdiff
path: root/src/routine.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/routine.cpp')
-rw-r--r--src/routine.cpp131
1 files changed, 131 insertions, 0 deletions
diff --git a/src/routine.cpp b/src/routine.cpp
new file mode 100644
index 00000000..d3590896
--- /dev/null
+++ b/src/routine.cpp
@@ -0,0 +1,131 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file implements the Routine base class (see the header for information about the class).
+//
+// =================================================================================================
+
+#include <string>
+#include <vector>
+
+#include "routine.hpp"
+
+namespace clblast {
+// =================================================================================================
+
+// Constructor: not much here, because no status codes can be returned
+Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
+ const std::vector<std::string> &routines, const Precision precision):
+ precision_(precision),
+ routine_name_(name),
+ queue_(queue),
+ event_(event),
+ context_(queue_.GetContext()),
+ device_(queue_.GetDevice()),
+ device_name_(device_.Name()),
+ db_(queue_, routines, precision_) {
+}
+
+// =================================================================================================
+
+// Separate set-up function to allow for status codes to be returned
+StatusCode Routine::SetUp() {
+
+ // Queries the cache to see whether or not the program (context-specific) is already there
+ if (ProgramIsInCache(context_, precision_, routine_name_)) { return StatusCode::kSuccess; }
+
+ // Queries the cache to see whether or not the binary (device-specific) is already there. If it
+ // is, a program is created and stored in the cache
+ if (BinaryIsInCache(device_name_, precision_, routine_name_)) {
+ try {
+ auto& binary = GetBinaryFromCache(device_name_, precision_, routine_name_);
+ auto program = Program(device_, context_, binary);
+ auto options = std::vector<std::string>();
+ program.Build(device_, options);
+ StoreProgramToCache(program, context_, precision_, routine_name_);
+ } catch (...) { return StatusCode::kBuildProgramFailure; }
+ return StatusCode::kSuccess;
+ }
+
+ // Otherwise, the kernel will be compiled and program will be built. Both the binary and the
+ // program will be added to the cache.
+
+ // Inspects whether or not cl_khr_fp64 is supported in case of double precision
+ const auto extensions = device_.Capabilities();
+ if (precision_ == Precision::kDouble || precision_ == Precision::kComplexDouble) {
+ if (extensions.find(kKhronosDoublePrecision) == std::string::npos) {
+ return StatusCode::kNoDoublePrecision;
+ }
+ }
+
+ // As above, but for cl_khr_fp16 (half precision)
+ if (precision_ == Precision::kHalf) {
+ if (extensions.find(kKhronosHalfPrecision) == std::string::npos) {
+ return StatusCode::kNoHalfPrecision;
+ }
+ }
+
+ // Loads the common header (typedefs and defines and such)
+ std::string common_header =
+ #include "kernels/common.opencl"
+ ;
+
+ // Collects the parameters for this device in the form of defines, and adds the precision
+ auto defines = db_.GetDefines();
+ defines += "#define PRECISION "+ToString(static_cast<int>(precision_))+"\n";
+
+ // Adds the name of the routine as a define
+ defines += "#define ROUTINE_"+routine_name_+"\n";
+
+ // For specific devices, use the non-IEE754 compilant OpenCL mad() instruction. This can improve
+ // performance, but might result in a reduced accuracy.
+ if (device_.IsAMD() && device_.IsGPU()) {
+ defines += "#define USE_CL_MAD 1\n";
+ }
+
+ // For specific devices, use staggered/shuffled workgroup indices.
+ if (device_.IsAMD() && device_.IsGPU()) {
+ defines += "#define USE_STAGGERED_INDICES 1\n";
+ }
+
+ // For specific devices add a global synchronisation barrier to the GEMM kernel to optimize
+ // performance through better cache behaviour
+ if (device_.IsARM() && device_.IsGPU()) {
+ defines += "#define GLOBAL_MEM_FENCE 1\n";
+ }
+
+ // Combines everything together into a single source string
+ const auto source_string = defines + common_header + source_string_;
+
+ // Compiles the kernel
+ try {
+ auto program = Program(context_, source_string);
+ auto options = std::vector<std::string>();
+ const auto build_status = program.Build(device_, options);
+
+ // Checks for compiler crashes/errors/warnings
+ if (build_status == BuildStatus::kError) {
+ const auto message = program.GetBuildInfo(device_);
+ fprintf(stdout, "OpenCL compiler error/warning: %s\n", message.c_str());
+ return StatusCode::kBuildProgramFailure;
+ }
+ if (build_status == BuildStatus::kInvalid) { return StatusCode::kInvalidBinary; }
+
+ // Store the compiled binary and program in the cache
+ const auto binary = program.GetIR();
+ StoreBinaryToCache(binary, device_name_, precision_, routine_name_);
+ StoreProgramToCache(program, context_, precision_, routine_name_);
+ } catch (...) { return StatusCode::kBuildProgramFailure; }
+
+ // No errors, normal termination of this function
+ return StatusCode::kSuccess;
+}
+
+// =================================================================================================
+} // namespace clblast