diff options
Diffstat (limited to 'src/routine.cpp')
-rw-r--r-- | src/routine.cpp | 85 |
1 files changed, 37 insertions, 48 deletions
diff --git a/src/routine.cpp b/src/routine.cpp index 80764b74..acafb0d2 100644 --- a/src/routine.cpp +++ b/src/routine.cpp @@ -21,10 +21,11 @@ namespace clblast { // ================================================================================================= -// Constructor: not much here, because no status codes can be returned +// The constructor does all heavy work, errors are returned as exceptions Routine::Routine(Queue &queue, EventPointer event, const std::string &name, const std::vector<std::string> &routines, const Precision precision, - const std::vector<const Database::DatabaseEntry*> &userDatabase): + const std::vector<const Database::DatabaseEntry*> &userDatabase, + std::initializer_list<const char *> source): precision_(precision), routine_name_(name), queue_(queue), @@ -33,15 +34,9 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name, device_(queue_.GetDevice()), device_name_(device_.Name()), db_(queue_, routines, precision_, userDatabase) { -} - -// ================================================================================================= - -// Separate set-up function to allow for status codes to be returned -StatusCode Routine::SetUp() { // Queries the cache to see whether or not the program (context-specific) is already there - if (ProgramIsInCache(context_, precision_, routine_name_)) { return StatusCode::kSuccess; } + if (ProgramIsInCache(context_, precision_, routine_name_)) { return; } // Sets the build options from an environmental variable (if set) auto options = std::vector<std::string>(); @@ -53,13 +48,10 @@ StatusCode Routine::SetUp() { // Queries the cache to see whether or not the binary (device-specific) is already there. If it // is, a program is created and stored in the cache if (BinaryIsInCache(device_name_, precision_, routine_name_)) { - try { - auto& binary = GetBinaryFromCache(device_name_, precision_, routine_name_); - auto program = Program(device_, context_, binary); - program.Build(device_, options); - StoreProgramToCache(program, context_, precision_, routine_name_); - } catch (...) { return StatusCode::kBuildProgramFailure; } - return StatusCode::kSuccess; + auto& binary = GetBinaryFromCache(device_name_, precision_, routine_name_); + auto program = Program(device_, context_, binary); + program.Build(device_, options); + StoreProgramToCache(program, context_, precision_, routine_name_); } // Otherwise, the kernel will be compiled and program will be built. Both the binary and the @@ -69,48 +61,50 @@ StatusCode Routine::SetUp() { const auto extensions = device_.Capabilities(); if (precision_ == Precision::kDouble || precision_ == Precision::kComplexDouble) { if (extensions.find(kKhronosDoublePrecision) == std::string::npos) { - return StatusCode::kNoDoublePrecision; + throw RuntimeErrorCode(StatusCode::kNoDoublePrecision); } } // As above, but for cl_khr_fp16 (half precision) if (precision_ == Precision::kHalf) { if (extensions.find(kKhronosHalfPrecision) == std::string::npos) { - return StatusCode::kNoHalfPrecision; + throw RuntimeErrorCode(StatusCode::kNoHalfPrecision); } } - // Loads the common header (typedefs and defines and such) - std::string common_header = - #include "kernels/common.opencl" - ; - // Collects the parameters for this device in the form of defines, and adds the precision - auto defines = db_.GetDefines(); - defines += "#define PRECISION "+ToString(static_cast<int>(precision_))+"\n"; + auto source_string = db_.GetDefines(); + source_string += "#define PRECISION "+ToString(static_cast<int>(precision_))+"\n"; // Adds the name of the routine as a define - defines += "#define ROUTINE_"+routine_name_+"\n"; + source_string += "#define ROUTINE_"+routine_name_+"\n"; // For specific devices, use the non-IEE754 compilant OpenCL mad() instruction. This can improve // performance, but might result in a reduced accuracy. if (device_.IsAMD() && device_.IsGPU()) { - defines += "#define USE_CL_MAD 1\n"; + source_string += "#define USE_CL_MAD 1\n"; } // For specific devices, use staggered/shuffled workgroup indices. if (device_.IsAMD() && device_.IsGPU()) { - defines += "#define USE_STAGGERED_INDICES 1\n"; + source_string += "#define USE_STAGGERED_INDICES 1\n"; } // For specific devices add a global synchronisation barrier to the GEMM kernel to optimize // performance through better cache behaviour if (device_.IsARM() && device_.IsGPU()) { - defines += "#define GLOBAL_MEM_FENCE 1\n"; + source_string += "#define GLOBAL_MEM_FENCE 1\n"; } - // Combines everything together into a single source string - const auto source_string = defines + common_header + source_string_; + // Loads the common header (typedefs and defines and such) + source_string += + #include "kernels/common.opencl" + ; + + // Adds routine-specific code to the constructed source string + for (const char *s: source) { + source_string += s; + } // Prints details of the routine to compile in case of debugging in verbose mode #ifdef VERBOSE @@ -120,23 +114,21 @@ StatusCode Routine::SetUp() { #endif // Compiles the kernel + auto program = Program(context_, source_string); try { - auto program = Program(context_, source_string); - const auto build_status = program.Build(device_, options); - - // Checks for compiler crashes/errors/warnings - if (build_status == BuildStatus::kError) { - const auto message = program.GetBuildInfo(device_); - fprintf(stdout, "OpenCL compiler error/warning: %s\n", message.c_str()); - return StatusCode::kBuildProgramFailure; + program.Build(device_, options); + } catch (const CLError &e) { + if (e.status() == CL_BUILD_PROGRAM_FAILURE) { + fprintf(stdout, "OpenCL compiler error/warning: %s\n", + program.GetBuildInfo(device_).c_str()); } - if (build_status == BuildStatus::kInvalid) { return StatusCode::kInvalidBinary; } + throw; + } - // Store the compiled binary and program in the cache - const auto binary = program.GetIR(); - StoreBinaryToCache(binary, device_name_, precision_, routine_name_); - StoreProgramToCache(program, context_, precision_, routine_name_); - } catch (...) { return StatusCode::kBuildProgramFailure; } + // Store the compiled binary and program in the cache + const auto binary = program.GetIR(); + StoreBinaryToCache(binary, device_name_, precision_, routine_name_); + StoreProgramToCache(program, context_, precision_, routine_name_); // Prints the elapsed compilation time in case of debugging in verbose mode #ifdef VERBOSE @@ -144,9 +136,6 @@ StatusCode Routine::SetUp() { const auto timing = std::chrono::duration<double,std::milli>(elapsed_time).count(); printf("[DEBUG] Completed compilation in %.2lf ms\n", timing); #endif - - // No errors, normal termination of this function - return StatusCode::kSuccess; } // ================================================================================================= |