diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/cache.cpp | 32 | ||||
-rw-r--r-- | src/cache.hpp | 11 | ||||
-rw-r--r-- | src/clblast.cpp | 43 | ||||
-rw-r--r-- | src/clblast_c.cpp | 21 | ||||
-rw-r--r-- | src/database/database.cpp | 32 | ||||
-rw-r--r-- | src/database/database.hpp | 33 | ||||
-rw-r--r-- | src/routine.cpp | 61 | ||||
-rw-r--r-- | src/routine.hpp | 18 | ||||
-rw-r--r-- | src/routines/common.hpp | 2 | ||||
-rw-r--r-- | src/routines/level2/xgemv.cpp | 2 |
10 files changed, 215 insertions, 40 deletions
diff --git a/src/cache.cpp b/src/cache.cpp index c5cc6a4d..4b74b0a1 100644 --- a/src/cache.cpp +++ b/src/cache.cpp @@ -65,6 +65,37 @@ void Cache<Key, Value>::Store(Key &&key, Value &&value) { } template <typename Key, typename Value> +void Cache<Key, Value>::Remove(const Key &key) { + std::lock_guard<std::mutex> lock(cache_mutex_); +#if __cplusplus >= 201402L + cache_.erase(key); +#else + auto it = cache_.begin(); + while (it != cache_.end()) { + if ((*it).first == key) { + it = cache_.erase(it); + } + else ++it; + } +#endif +} + +template <typename Key, typename Value> +template <int I1, int I2> +void Cache<Key, Value>::RemoveBySubset(const Key &key) { + std::lock_guard<std::mutex> lock(cache_mutex_); + auto it = cache_.begin(); + while (it != cache_.end()) { + const auto current_key = (*it).first; + if ((std::get<I1>(key) == std::get<I1>(current_key)) && + (std::get<I2>(key) == std::get<I2>(current_key))) { + it = cache_.erase(it); + } + else ++it; + } +} + +template <typename Key, typename Value> void Cache<Key, Value>::Invalidate() { std::lock_guard<std::mutex> lock(cache_mutex_); @@ -88,6 +119,7 @@ template std::string BinaryCache::Get(const BinaryKeyRef &, bool *) const; template class Cache<ProgramKey, Program>; template Program ProgramCache::Get(const ProgramKeyRef &, bool *) const; +template void ProgramCache::RemoveBySubset<1, 2>(const ProgramKey &); // precision and routine name // ================================================================================================= diff --git a/src/cache.hpp b/src/cache.hpp index c3675f07..694de839 100644 --- a/src/cache.hpp +++ b/src/cache.hpp @@ -42,6 +42,10 @@ public: void Store(Key &&key, Value &&value); void Invalidate(); + // Removes all entries with a given key + void Remove(const Key &key); + template <int I1, int I2> void RemoveBySubset(const Key &key); // currently supports 2 indices + static Cache<Key, Value> &Instance(); private: @@ -72,7 +76,6 @@ typedef Cache<BinaryKey, std::string> 
BinaryCache; extern template class Cache<BinaryKey, std::string>; extern template std::string BinaryCache::Get(const BinaryKeyRef &, bool *) const; - // ================================================================================================= // The key struct for the cache of compiled OpenCL programs (context-dependent) @@ -90,9 +93,9 @@ extern template Program ProgramCache::Get(const ProgramKeyRef &, bool *) const; class Database; // The key struct for the cache of database maps. -// Order of fields: precision, device_name, routines (smaller fields first) -typedef std::tuple<Precision, std::string, std::vector<std::string>> DatabaseKey; -typedef std::tuple<const Precision &, const std::string &, const std::vector<std::string> &> DatabaseKeyRef; +// Order of fields: precision, device_name, kernel_name (smaller fields first) +typedef std::tuple<Precision, std::string, std::string> DatabaseKey; +typedef std::tuple<const Precision &, const std::string &, const std::string &> DatabaseKeyRef; typedef Cache<DatabaseKey, Database> DatabaseCache; diff --git a/src/clblast.cpp b/src/clblast.cpp index 35f3f552..a8e4d084 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2254,4 +2254,47 @@ StatusCode FillCache(const cl_device_id device) { } // ================================================================================================= + +// Overrides the tuning parameters for this device-precision-kernel combination +StatusCode OverrideParameters(const cl_device_id device, const std::string &kernel_name, + const Precision precision, + const std::unordered_map<std::string,size_t> &parameters) { + try { + + // Retrieves the device name + const auto device_cpp = Device(device); + const auto device_name = device_cpp.Name(); + + // Retrieves the current database values to verify whether the new ones are complete + auto in_cache = false; + const auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision, device_name, kernel_name }, &in_cache); + if 
(!in_cache) { return StatusCode::kInvalidOverrideKernel; } + for (const auto &current_param : current_database.GetParameterNames()) { + if (parameters.find(current_param) == parameters.end()) { + return StatusCode::kMissingOverrideParameter; + } + } + + // Clears the existing program & binary cache for routines with the target kernel + const auto routine_names = Routine::routines_by_kernel.at(kernel_name); + for (const auto &routine_name : routine_names) { + ProgramCache::Instance().RemoveBySubset<1, 2>(ProgramKey{nullptr, precision, routine_name}); + BinaryCache::Instance().Remove(BinaryKey{precision, routine_name, device_name}); + } + + // Creates a small custom database based on the provided parameters + const auto database_device = Database::DatabaseDevice{"default", parameters}; + const auto database_vendor = Database::DatabaseVendor{database::kDeviceTypeAll, "default", {database_device}}; + const auto database_entry = Database::DatabaseEntry{kernel_name, precision, {database_vendor}}; + const auto database = Database(device_cpp, kernel_name, precision, {&database_entry}); + + // Removes the old database entry and stores the new one in the cache + DatabaseCache::Instance().Remove(DatabaseKey{ precision, device_name, kernel_name }); + DatabaseCache::Instance().Store(DatabaseKey{ precision, device_name, kernel_name }, Database(database)); + + } catch (...) 
{ return DispatchException(); } + return StatusCode::kSuccess; +} + +// ================================================================================================= } // namespace clblast diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp index e4f2b3ed..de431fa4 100644 --- a/src/clblast_c.cpp +++ b/src/clblast_c.cpp @@ -12,6 +12,7 @@ // ================================================================================================= #include <string> +#include <unordered_map> #include "utilities/utilities.hpp" #include "clblast_c.h" @@ -3484,3 +3485,23 @@ CLBlastStatusCode CLBlastFillCache(const cl_device_id device) { } // ================================================================================================= + +// Overrides the tuning parameters for this device-precision-kernel combination +CLBlastStatusCode PUBLIC_API CLBlastOverrideParameters(const cl_device_id device, const char* kernel_name, + const CLBlastPrecision precision, const size_t num_parameters, + const char** parameters_names, const size_t* parameters_values) { + try { + const auto kernel_name_cpp = std::string(kernel_name); + const auto precision_cpp = static_cast<clblast::Precision>(precision); + auto parameters = std::unordered_map<std::string, size_t>(); + for (auto i = size_t{0}; i < num_parameters; ++i) { + const auto parameter_name = std::string(parameters_names[i]); + const auto parameter_value = parameters_values[i]; + parameters[parameter_name] = parameter_value; + } + const auto status = clblast::OverrideParameters(device, kernel_name_cpp, precision_cpp, parameters); + return static_cast<CLBlastStatusCode>(status); + } catch (...) 
{ return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } +} + +// ================================================================================================= diff --git a/src/database/database.cpp b/src/database/database.cpp index c1cb9d56..02d0b139 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -63,7 +63,7 @@ const std::unordered_map<std::string, std::string> Database::kVendorNames{ // Constructor, computing device properties and populating the parameter-vector from the database. // This takes an optional overlay database in case of custom tuning or custom kernels. -Database::Database(const Device &device, const std::vector<std::string> &kernels, +Database::Database(const Device &device, const std::string &kernel_name, const Precision precision, const std::vector<const DatabaseEntry*> &overlay): parameters_(std::make_shared<Parameters>()) { @@ -79,20 +79,17 @@ Database::Database(const Device &device, const std::vector<std::string> &kernels } } - // Iterates over all kernels to include, and retrieves the parameters for each of them - for (auto &kernel: kernels) { - auto search_result = ParametersPtr{}; - - for (auto &db: { database, overlay}) { - search_result = Search(kernel, device_type, device_vendor, device_name, precision, db); - if (search_result) { - parameters_->insert(search_result->begin(), search_result->end()); - break; - } + // Searches potentially multiple databases + auto search_result = ParametersPtr{}; + for (auto &db: { overlay, database}) { + search_result = Search(kernel_name, device_type, device_vendor, device_name, precision, db); + if (search_result) { + parameters_->insert(search_result->begin(), search_result->end()); + break; } - - if (!search_result) { throw RuntimeErrorCode(StatusCode::kDatabaseError); } } + + if (!search_result) { throw RuntimeErrorCode(StatusCode::kDatabaseError); } } // ================================================================================================= 
@@ -106,6 +103,15 @@ std::string Database::GetDefines() const { return defines; } +// Retrieves the names of all the parameters +std::vector<std::string> Database::GetParameterNames() const { + auto parameter_names = std::vector<std::string>(); + for (auto &parameter: *parameters_) { + parameter_names.push_back(parameter.first); + } + return parameter_names; +} + // ================================================================================================= // Searches a particular database for the right kernel and precision diff --git a/src/database/database.hpp b/src/database/database.hpp index 87c12293..b34e0d8a 100644 --- a/src/database/database.hpp +++ b/src/database/database.hpp @@ -75,15 +75,19 @@ class Database { Database() = default; // The constructor with a user-provided database overlay (potentially an empty vector) - explicit Database(const Device &device, const std::vector<std::string> &routines, + explicit Database(const Device &device, const std::string &kernel_name, const Precision precision, const std::vector<const DatabaseEntry*> &overlay); // Accessor of values by key - size_t operator[](const std::string key) const { return parameters_->find(key)->second; } + size_t operator[](const std::string &key) const { return parameters_->find(key)->second; } + bool exists(const std::string &key) const { return (parameters_->count(key) == 1); } // Obtain a list of OpenCL pre-processor defines based on the parameters std::string GetDefines() const; + // Retrieves the names of all the parameters + std::vector<std::string> GetParameterNames() const; + private: // Search method for a specified database, returning pointer (possibly a nullptr) ParametersPtr Search(const std::string &this_kernel, const std::string &this_type, @@ -96,6 +100,31 @@ class Database { }; // ================================================================================================= + +// Multiple databases together in a map +class Databases { + public: + + explicit 
Databases(const std::vector<std::string> &kernel_names): kernel_names_(kernel_names) { } + + // Database accessor + Database& operator()(const std::string &kernel_name) { return databases_[kernel_name]; } + + // Retrieves a parameter from the database + size_t operator[](const std::string &key) const { + for (const auto &kernel_name : kernel_names_) { + const auto &kernel_db = databases_.find(kernel_name)->second; + if (kernel_db.exists(key)) { return kernel_db[key]; } + } + throw RuntimeErrorCode(StatusCode::kDatabaseError); + } + + private: + const std::vector<std::string> kernel_names_; + std::unordered_map<std::string, Database> databases_; +}; + +// ================================================================================================= } // namespace clblast // CLBLAST_DATABASE_H_ diff --git a/src/routine.cpp b/src/routine.cpp index 4fe04a60..3cd045c8 100644 --- a/src/routine.cpp +++ b/src/routine.cpp @@ -21,36 +21,64 @@ namespace clblast { // ================================================================================================= +// For each kernel this map contains a list of routines it is used in +const std::vector<std::string> Routine::routines_axpy = {"AXPY", "COPY", "SCAL", "SWAP"}; +const std::vector<std::string> Routine::routines_dot = {"AMAX", "ASUM", "DOT", "DOTC", "DOTU", "MAX", "MIN", "NRM2", "SUM"}; +const std::vector<std::string> Routine::routines_ger = {"GER", "GERC", "GERU", "HER", "HER2", "HPR", "HPR2", "SPR", "SPR2", "SYR", "SYR2"}; +const std::vector<std::string> Routine::routines_gemv = {"GBMV", "GEMV", "HBMV", "HEMV", "HPMV", "SBMV", "SPMV", "SYMV", "TMBV", "TPMV", "TRMV"}; +const std::vector<std::string> Routine::routines_gemm = {"GEMM", "HEMM", "SYMM", "TRMM"}; +const std::vector<std::string> Routine::routines_gemm_syrk = {"GEMM", "HEMM", "HER2K", "HERK", "SYMM", "SYR2K", "SYRK", "TRMM"}; +const std::unordered_map<std::string, const std::vector<std::string>> Routine::routines_by_kernel = { + {"Xaxpy", routines_axpy}, 
+ {"Xdot", routines_dot}, + {"Xgemv", routines_gemv}, + {"XgemvFast", routines_gemv}, + {"XgemvFastRot", routines_gemv}, + {"Xgemv", {}}, + {"Xger", routines_ger}, + {"Copy", routines_gemm_syrk}, + {"Pad", routines_gemm_syrk}, + {"Transpose", routines_gemm_syrk}, + {"Padtranspose", routines_gemm_syrk}, + {"Xgemm", routines_gemm_syrk}, + {"XgemmDirect", routines_gemm}, + {"KernelSelection", routines_gemm}, +}; +// ================================================================================================= + // The constructor does all heavy work, errors are returned as exceptions Routine::Routine(Queue &queue, EventPointer event, const std::string &name, - const std::vector<std::string> &routines, const Precision precision, + const std::vector<std::string> &kernel_names, const Precision precision, const std::vector<const Database::DatabaseEntry*> &userDatabase, std::initializer_list<const char *> source): precision_(precision), routine_name_(name), + kernel_names_(kernel_names), queue_(queue), event_(event), context_(queue_.GetContext()), device_(queue_.GetDevice()), - device_name_(device_.Name()) { + device_name_(device_.Name()), + db_(kernel_names) { - InitDatabase(routines, userDatabase); + InitDatabase(userDatabase); InitProgram(source); } -void Routine::InitDatabase(const std::vector<std::string> &routines, - const std::vector<const Database::DatabaseEntry*> &userDatabase) { +void Routine::InitDatabase(const std::vector<const Database::DatabaseEntry*> &userDatabase) { + for (const auto &kernel_name : kernel_names_) { - // Queries the cache to see whether or not the kernel parameter database is already there - bool has_db; - db_ = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision_, device_name_, routines }, - &has_db); - if (has_db) { return; } + // Queries the cache to see whether or not the kernel parameter database is already there + bool has_db; + db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision_, device_name_, 
kernel_name }, + &has_db); + if (has_db) { continue; } - // Builds the parameter database for this device and routine set and stores it in the cache - db_ = Database(device_, routines, precision_, userDatabase); - DatabaseCache::Instance().Store(DatabaseKey{ precision_, device_name_, routines }, - Database{ db_ }); + // Builds the parameter database for this device and routine set and stores it in the cache + db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase); + DatabaseCache::Instance().Store(DatabaseKey{ precision_, device_name_, kernel_name }, + Database{ db_(kernel_name) }); + } } void Routine::InitProgram(std::initializer_list<const char *> source) { @@ -96,7 +124,10 @@ void Routine::InitProgram(std::initializer_list<const char *> source) { } // Collects the parameters for this device in the form of defines, and adds the precision - auto source_string = db_.GetDefines(); + auto source_string = std::string{""}; + for (const auto &kernel_name : kernel_names_) { + source_string += db_(kernel_name).GetDefines(); + } source_string += "#define PRECISION "+ToString(static_cast<int>(precision_))+"\n"; // Adds the name of the routine as a define diff --git a/src/routine.hpp b/src/routine.hpp index f366e4d9..622a1c0d 100644 --- a/src/routine.hpp +++ b/src/routine.hpp @@ -18,6 +18,7 @@ #include <string> #include <vector> +#include <unordered_map> #include "utilities/utilities.hpp" #include "cache.hpp" @@ -42,22 +43,31 @@ class Routine { const std::vector<const Database::DatabaseEntry*> &userDatabase, std::initializer_list<const char *> source); + // List of kernel-routine look-ups + static const std::vector<std::string> routines_axpy; + static const std::vector<std::string> routines_dot; + static const std::vector<std::string> routines_ger; + static const std::vector<std::string> routines_gemv; + static const std::vector<std::string> routines_gemm; + static const std::vector<std::string> routines_gemm_syrk; + static const 
std::unordered_map<std::string, const std::vector<std::string>> routines_by_kernel; + private: // Initializes program_, fetching cached program or building one void InitProgram(std::initializer_list<const char *> source); // Initializes db_, fetching cached database or building one - void InitDatabase(const std::vector<std::string> &routines, - const std::vector<const Database::DatabaseEntry*> &userDatabase); + void InitDatabase(const std::vector<const Database::DatabaseEntry*> &userDatabase); protected: // Non-static variable for the precision const Precision precision_; - // The routine's name + // The routine's name and the corresponding kernels const std::string routine_name_; + const std::vector<std::string> kernel_names_; // The OpenCL objects, accessible only from derived classes Queue queue_; @@ -72,7 +82,7 @@ class Routine { Program program_; // Connection to the database for all the device-specific parameters - Database db_; + Databases db_; }; // ================================================================================================= diff --git a/src/routines/common.hpp b/src/routines/common.hpp index 7c211c0d..d268d58b 100644 --- a/src/routines/common.hpp +++ b/src/routines/common.hpp @@ -37,7 +37,7 @@ void RunKernel(Kernel &kernel, Queue &queue, const Device &device, // to write to symmetric and triangular matrices through optional arguments. 
template <typename T> void PadCopyTransposeMatrix(Queue &queue, const Device &device, - const Database &db, + const Databases &db, EventPointer event, const std::vector<Event> &waitForEvents, const size_t src_one, const size_t src_two, const size_t src_ld, const size_t src_offset, diff --git a/src/routines/level2/xgemv.cpp b/src/routines/level2/xgemv.cpp index 9e9c2db4..aae66798 100644 --- a/src/routines/level2/xgemv.cpp +++ b/src/routines/level2/xgemv.cpp @@ -22,7 +22,7 @@ namespace clblast { // Constructor: forwards to base class constructor template <typename T> Xgemv<T>::Xgemv(Queue &queue, EventPointer event, const std::string &name): - Routine(queue, event, name, {"Pad", "Xgemv", "XgemvFast", "XgemvFastRot"}, PrecisionValue<T>(), {}, { + Routine(queue, event, name, {"Xgemv", "XgemvFast", "XgemvFastRot"}, PrecisionValue<T>(), {}, { #include "../../kernels/level2/xgemv.opencl" #include "../../kernels/level2/xgemv_fast.opencl" }) { |