diff options
-rw-r--r-- | CHANGELOG | 2 | ||||
-rw-r--r-- | include/internal/routine.h | 4 | ||||
-rw-r--r-- | src/routine.cc | 24 |
3 files changed, 23 insertions, 7 deletions
@@ -1,6 +1,6 @@ Development version (next release) -- +- Made the library thread-safe Version 0.6.0 - Added support for MSVC (Visual Studio) 2015 diff --git a/include/internal/routine.h b/include/internal/routine.h index b7c06a97..5f5b8211 100644 --- a/include/internal/routine.h +++ b/include/internal/routine.h @@ -18,6 +18,7 @@ #include <string> #include <vector> +#include <mutex> #include "internal/utilities.h" #include "internal/database.h" @@ -46,8 +47,9 @@ class Routine { } }; - // The actual cache, implemented as a vector of the above data-type + // The actual cache, implemented as a vector of the above data-type, and its mutex static std::vector<ProgramCache> program_cache_; + static std::mutex program_cache_mutex_; // Helper functions which check for errors in the status code static constexpr bool ErrorIn(const StatusCode s) { return (s != StatusCode::kSuccess); } diff --git a/src/routine.cc b/src/routine.cc index 2978c94a..ff7b3e1a 100644 --- a/src/routine.cc +++ b/src/routine.cc @@ -11,14 +11,18 @@ // // ================================================================================================= +#include <string> +#include <vector> +#include <mutex> + #include "internal/routine.h" namespace clblast { // ================================================================================================= -// The cache of compiled OpenCL programs -template <typename T> -std::vector<typename Routine<T>::ProgramCache> Routine<T>::program_cache_; +// The cache of compiled OpenCL programs and its mutex for thread safety +template <typename T> std::vector<typename Routine<T>::ProgramCache> Routine<T>::program_cache_; +template <typename T> std::mutex Routine<T>::program_cache_mutex_; // Constructor: not much here, because no status codes can be returned template <typename T> @@ -97,8 +101,10 @@ StatusCode Routine<T>::SetUp() { } if (build_status == BuildStatus::kInvalid) { return StatusCode::kInvalidBinary; } - // Store the compiled program in the cache + // Store the compiled program in the cache (atomic for thread-safety) + program_cache_mutex_.lock(); program_cache_.push_back({program, device_name_, precision_, routine_name_}); + program_cache_mutex_.unlock(); } catch (...) { return StatusCode::kBuildProgramFailure; } } @@ -367,20 +373,28 @@ StatusCode Routine<T>::PadCopyTransposeMatrix(const size_t src_one, const size_t // otherwise. template <typename T> const Program& Routine<T>::GetProgramFromCache() const { + program_cache_mutex_.lock(); for (auto &cached_program: program_cache_) { if (cached_program.MatchInCache(device_name_, precision_, routine_name_)) { + program_cache_mutex_.unlock(); return cached_program.program; } } + program_cache_mutex_.unlock(); throw std::runtime_error("Internal CLBlast error: Expected program in cache, but found none."); } // Queries the cache to see whether or not the compiled kernel is already there template <typename T> bool Routine<T>::ProgramIsInCache() const { + program_cache_mutex_.lock(); for (auto &cached_program: program_cache_) { - if (cached_program.MatchInCache(device_name_, precision_, routine_name_)) { return true; } + if (cached_program.MatchInCache(device_name_, precision_, routine_name_)) { + program_cache_mutex_.unlock(); + return true; + } } + program_cache_mutex_.unlock(); return false; } |