diff options
Diffstat (limited to 'src/routine.cc')
-rw-r--r-- | src/routine.cc | 98 |
1 files changed, 59 insertions, 39 deletions
diff --git a/src/routine.cc b/src/routine.cc index aded1a31..31476c42 100644 --- a/src/routine.cc +++ b/src/routine.cc @@ -13,17 +13,17 @@ #include "internal/routine.h" -#include "internal/utilities.h" - namespace clblast { // ================================================================================================= // The cache of compiled OpenCL programs -std::vector<Routine::ProgramCache> Routine::program_cache_; +template <typename T> +std::vector<typename Routine<T>::ProgramCache> Routine<T>::program_cache_; // Constructor: not much here, because no status codes can be returned -Routine::Routine(CommandQueue &queue, Event &event, const std::string &name, - const std::vector<std::string> &routines, const Precision precision): +template <typename T> +Routine<T>::Routine(Queue &queue, Event &event, const std::string &name, + const std::vector<std::string> &routines, const Precision precision): precision_(precision), routine_name_(name), queue_(queue), @@ -40,14 +40,15 @@ Routine::Routine(CommandQueue &queue, Event &event, const std::string &name, // ================================================================================================= // Separate set-up function to allow for status codes to be returned -StatusCode Routine::SetUp() { +template <typename T> +StatusCode Routine<T>::SetUp() { // Queries the cache to see whether or not the compiled kernel is already there. If not, it will // be built and added to the cache. if (!ProgramIsInCache()) { // Inspects whether or not cl_khr_fp64 is supported in case of double precision - auto extensions = device_.Extensions(); + auto extensions = device_.Capabilities(); if (precision_ == Precision::kDouble || precision_ == Precision::kComplexDouble) { if (extensions.find(kKhronosDoublePrecision) == std::string::npos) { return StatusCode::kNoDoublePrecision; @@ -85,16 +86,16 @@ StatusCode Routine::SetUp() { // Compiles the kernel try { auto program = Program(context_, source_string); - auto options = std::string{}; - auto status = program.Build(device_, options); + auto options = std::vector<std::string>(); + auto build_status = program.Build(device_, options); // Checks for compiler crashes/errors/warnings - if (status == CL_BUILD_PROGRAM_FAILURE) { + if (build_status == BuildStatus::kError) { auto message = program.GetBuildInfo(device_); fprintf(stdout, "OpenCL compiler error/warning: %s\n", message.c_str()); return StatusCode::kBuildProgramFailure; } - if (status == CL_INVALID_BINARY) { return StatusCode::kInvalidBinary; } + if (build_status == BuildStatus::kInvalid) { return StatusCode::kInvalidBinary; } // Store the compiled program in the cache program_cache_.push_back({program, device_name_, precision_, routine_name_}); @@ -108,8 +109,9 @@ StatusCode Routine::SetUp() { // ================================================================================================= // Enqueues a kernel, waits for completion, and checks for errors -StatusCode Routine::RunKernel(const Kernel &kernel, std::vector<size_t> &global, - const std::vector<size_t> &local) { +template <typename T> +StatusCode Routine<T>::RunKernel(Kernel &kernel, std::vector<size_t> &global, + const std::vector<size_t> &local) { // Tests for validity of the local thread sizes if (local.size() > max_work_item_dimensions_) { @@ -132,12 +134,14 @@ StatusCode Routine::RunKernel(const Kernel &kernel, std::vector<size_t> &global, if (!device_.IsLocalMemoryValid(local_mem_usage)) { return StatusCode::kInvalidLocalMemUsage; } // Launches the kernel (and checks for launch errors) - auto status = queue_.EnqueueKernel(kernel, global, local, event_); - if (status != CL_SUCCESS) { return StatusCode::kKernelLaunchError; } + try { + kernel.Launch(queue_, global, local, event_); + } catch (...) { return StatusCode::kKernelLaunchError; } // Waits for completion of the kernel - status = event_.Wait(); - if (status != CL_SUCCESS) { return StatusCode::kKernelRunError; } + try { + queue_.Finish(event_); + } catch (...) { return StatusCode::kKernelRunError; } // No errors, normal termination of this function return StatusCode::kSuccess; @@ -147,8 +151,9 @@ StatusCode Routine::RunKernel(const Kernel &kernel, std::vector<size_t> &global, // Tests matrix A for validity: checks for a valid OpenCL buffer, a valid lead-dimension, and for a // sufficient buffer size. -StatusCode Routine::TestMatrixA(const size_t one, const size_t two, const Buffer &buffer, - const size_t offset, const size_t ld, const size_t data_size) { +template <typename T> +StatusCode Routine<T>::TestMatrixA(const size_t one, const size_t two, const Buffer<T> &buffer, + const size_t offset, const size_t ld, const size_t data_size) { if (ld < one) { return StatusCode::kInvalidLeadDimA; } try { auto required_size = (ld*two + offset)*data_size; @@ -160,8 +165,9 @@ StatusCode Routine::TestMatrixA(const size_t one, const size_t two, const Buffer // Tests matrix B for validity: checks for a valid OpenCL buffer, a valid lead-dimension, and for a // sufficient buffer size. -StatusCode Routine::TestMatrixB(const size_t one, const size_t two, const Buffer &buffer, - const size_t offset, const size_t ld, const size_t data_size) { +template <typename T> +StatusCode Routine<T>::TestMatrixB(const size_t one, const size_t two, const Buffer<T> &buffer, + const size_t offset, const size_t ld, const size_t data_size) { if (ld < one) { return StatusCode::kInvalidLeadDimB; } try { auto required_size = (ld*two + offset)*data_size; @@ -173,8 +179,9 @@ StatusCode Routine::TestMatrixB(const size_t one, const size_t two, const Buffer // Tests matrix C for validity: checks for a valid OpenCL buffer, a valid lead-dimension, and for a // sufficient buffer size. -StatusCode Routine::TestMatrixC(const size_t one, const size_t two, const Buffer &buffer, - const size_t offset, const size_t ld, const size_t data_size) { +template <typename T> +StatusCode Routine<T>::TestMatrixC(const size_t one, const size_t two, const Buffer<T> &buffer, + const size_t offset, const size_t ld, const size_t data_size) { if (ld < one) { return StatusCode::kInvalidLeadDimC; } try { auto required_size = (ld*two + offset)*data_size; @@ -188,8 +195,9 @@ StatusCode Routine::TestMatrixC(const size_t one, const size_t two, const Buffer // Tests vector X for validity: checks for a valid increment, a valid OpenCL buffer, and for a // sufficient buffer size. -StatusCode Routine::TestVectorX(const size_t n, const Buffer &buffer, const size_t offset, - const size_t inc, const size_t data_size) { +template <typename T> +StatusCode Routine<T>::TestVectorX(const size_t n, const Buffer<T> &buffer, const size_t offset, + const size_t inc, const size_t data_size) { if (inc == 0) { return StatusCode::kInvalidIncrementX; } try { auto required_size = (n*inc + offset)*data_size; @@ -201,8 +209,9 @@ StatusCode Routine::TestVectorX(const size_t n, const Buffer &buffer, const size // Tests vector Y for validity: checks for a valid increment, a valid OpenCL buffer, and for a // sufficient buffer size. -StatusCode Routine::TestVectorY(const size_t n, const Buffer &buffer, const size_t offset, - const size_t inc, const size_t data_size) { +template <typename T> +StatusCode Routine<T>::TestVectorY(const size_t n, const Buffer<T> &buffer, const size_t offset, + const size_t inc, const size_t data_size) { if (inc == 0) { return StatusCode::kInvalidIncrementY; } try { auto required_size = (n*inc + offset)*data_size; @@ -215,16 +224,17 @@ StatusCode Routine::TestVectorY(const size_t n, const Buffer &buffer, const size // ================================================================================================= // Copies or transposes a matrix and pads/unpads it with zeros -StatusCode Routine::PadCopyTransposeMatrix(const size_t src_one, const size_t src_two, - const size_t src_ld, const size_t src_offset, - const Buffer &src, - const size_t dest_one, const size_t dest_two, - const size_t dest_ld, const size_t dest_offset, - const Buffer &dest, - const Program &program, const bool do_pad, - const bool do_transpose, const bool do_conjugate, - const bool upper, const bool lower, - const bool diagonal_imag_zero) { +template <typename T> +StatusCode Routine<T>::PadCopyTransposeMatrix(const size_t src_one, const size_t src_two, + const size_t src_ld, const size_t src_offset, + const Buffer<T> &src, + const size_t dest_one, const size_t dest_two, + const size_t dest_ld, const size_t dest_offset, + const Buffer<T> &dest, + const Program &program, const bool do_pad, + const bool do_transpose, const bool do_conjugate, + const bool upper, const bool lower, + const bool diagonal_imag_zero) { // Determines whether or not the fast-version could potentially be used auto use_fast_kernel = (src_offset == 0) && (dest_offset == 0) && (do_conjugate == false) && @@ -328,7 +338,8 @@ StatusCode Routine::PadCopyTransposeMatrix(const size_t src_one, const size_t sr // Queries the cache and retrieves a matching program. Assumes that the match is available, throws // otherwise. -const Program& Routine::GetProgramFromCache() const { +template <typename T> +const Program& Routine<T>::GetProgramFromCache() const { for (auto &cached_program: program_cache_) { if (cached_program.MatchInCache(device_name_, precision_, routine_name_)) { return cached_program.program; @@ -338,7 +349,8 @@ const Program& Routine::GetProgramFromCache() const { } // Queries the cache to see whether or not the compiled kernel is already there -bool Routine::ProgramIsInCache() const { +template <typename T> +bool Routine<T>::ProgramIsInCache() const { for (auto &cached_program: program_cache_) { if (cached_program.MatchInCache(device_name_, precision_, routine_name_)) { return true; } } @@ -346,4 +358,12 @@ bool Routine::ProgramIsInCache() const { } // ================================================================================================= + +// Compiles the templated class +template class Routine<float>; +template class Routine<double>; +template class Routine<float2>; +template class Routine<double2>; + +// ================================================================================================= } // namespace clblast |