// ================================================================================================= // This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This // project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- // width of 100 characters per line. // // Author(s): // Cedric Nugteren // // This file implements the common (non-OpenCL-specific) functions of the CLBlast API. // // ================================================================================================= #include #include "utilities/utilities.hpp" #include "cache.hpp" #include "routines/routines.hpp" namespace clblast { // ================================================================================================= // Clears the cache of stored binaries StatusCode ClearCache() { try { ProgramCache::Instance().Invalidate(); BinaryCache::Instance().Invalidate(); } catch (...) { return DispatchException(); } return StatusCode::kSuccess; } template void FillCacheForPrecision(Queue &queue) { try { // Runs all the level 1 set-up functions Xswap(queue, nullptr); Xswap(queue, nullptr); Xswap(queue, nullptr); Xswap(queue, nullptr); Xscal(queue, nullptr); Xscal(queue, nullptr); Xcopy(queue, nullptr); Xcopy(queue, nullptr); Xaxpy(queue, nullptr); Xaxpy(queue, nullptr); Xdot(queue, nullptr); Xdotu(queue, nullptr); Xdotc(queue, nullptr); Xnrm2(queue, nullptr); Xnrm2(queue, nullptr); Xasum(queue, nullptr); Xasum(queue, nullptr); Xsum(queue, nullptr); Xsum(queue, nullptr); Xamax(queue, nullptr); Xamax(queue, nullptr); Xmax(queue, nullptr); Xmax(queue, nullptr); Xmin(queue, nullptr); Xmin(queue, nullptr); // Runs all the level 2 set-up functions Xgemv(queue, nullptr); Xgemv(queue, nullptr); Xgbmv(queue, nullptr); Xgbmv(queue, nullptr); Xhemv(queue, nullptr); Xhbmv(queue, nullptr); Xhpmv(queue, nullptr); Xsymv(queue, nullptr); Xsbmv(queue, nullptr); Xspmv(queue, nullptr); Xtrmv(queue, nullptr); Xtrmv(queue, nullptr); Xtbmv(queue, nullptr); Xtbmv(queue, nullptr); Xtpmv(queue, nullptr); Xtpmv(queue, nullptr); Xger(queue, nullptr); Xgeru(queue, nullptr); Xgerc(queue, nullptr); Xher(queue, nullptr); Xhpr(queue, nullptr); Xher2(queue, nullptr); Xhpr2(queue, nullptr); Xsyr(queue, nullptr); Xspr(queue, nullptr); Xsyr2(queue, nullptr); Xspr2(queue, nullptr); // Runs all the level 3 set-up functions Xgemm(queue, nullptr); Xgemm(queue, nullptr); Xsymm(queue, nullptr); Xsymm(queue, nullptr); Xhemm(queue, nullptr); Xsyrk(queue, nullptr); Xsyrk(queue, nullptr); Xherk(queue, nullptr); Xsyr2k(queue, nullptr); Xsyr2k(queue, nullptr); Xher2k(queue, nullptr); Xtrmm(queue, nullptr); Xtrmm(queue, nullptr); // Runs all the non-BLAS set-up functions Xomatcopy(queue, nullptr); Xomatcopy(queue, nullptr); } catch(const RuntimeErrorCode &e) { if (e.status() != StatusCode::kNoDoublePrecision && e.status() != StatusCode::kNoHalfPrecision) { throw; } } } // Fills the cache with all binaries for a specific device // TODO: Add half-precision FP16 set-up calls StatusCode FillCache(const RawDeviceID device) { try { // Creates a sample context and queue to match the normal routine calling conventions auto device_cpp = Device(device); auto context = Context(device_cpp); auto queue = Queue(context, device_cpp); FillCacheForPrecision(queue); FillCacheForPrecision(queue); } catch (...) { return DispatchException(); } return StatusCode::kSuccess; } // ================================================================================================= // Retrieves the current tuning parameters for this device-precision-kernel combination StatusCode RetrieveParameters(const RawDeviceID device, const std::string &kernel_name, const Precision precision, std::unordered_map ¶meters) { try { // Retrieves the device name const auto device_cpp = Device(device); const auto platform_id = device_cpp.PlatformID(); const auto device_name = GetDeviceName(device_cpp); // Retrieves the database values auto in_cache = false; auto database = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device, precision, kernel_name}, &in_cache); if (!in_cache) { log_debug("Searching database for kernel '" + kernel_name + "'"); database = Database(device_cpp, kernel_name, precision, {}); } // Retrieves the parameters for (const auto ¶meter: database.GetParameters()) { parameters[parameter.first] = parameter.second; } } catch (...) { return DispatchException(); } return StatusCode::kSuccess; } // Overrides the tuning parameters for this device-precision-kernel combination StatusCode OverrideParameters(const RawDeviceID device, const std::string &kernel_name, const Precision precision, const std::unordered_map ¶meters) { try { // Retrieves the device name const auto device_cpp = Device(device); const auto platform_id = device_cpp.PlatformID(); const auto device_name = GetDeviceName(device_cpp); // Retrieves the current database values to verify whether the new ones are complete auto in_cache = false; auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device, precision, kernel_name}, &in_cache); if (!in_cache) { log_debug("Searching database for kernel '" + kernel_name + "'"); current_database = Database(device_cpp, kernel_name, precision, {}); } // Verifies the parameters size const auto current_parameter_names = current_database.GetParameterNames(); if (current_parameter_names.size() > parameters.size()) { return StatusCode::kMissingOverrideParameter; } // Retrieves the names and values separately and in the same order as the existing database auto parameter_values = database::Params{0}; auto i = size_t{0}; for (const auto ¤t_param : current_parameter_names) { if (parameters.find(current_param) == parameters.end()) { return StatusCode::kMissingOverrideParameter; } const auto parameter_value = parameters.at(current_param); parameter_values[i] = parameter_value; ++i; } // Creates a small custom database based on the provided parameters const auto database_device = database::DatabaseDevice{database::kDeviceNameDefault, parameter_values}; const auto database_architecture = database::DatabaseArchitecture{"default", {database_device}}; const auto database_vendor = database::DatabaseVendor{database::kDeviceTypeAll, "default", {database_architecture}}; const auto database_entry = database::DatabaseEntry{kernel_name, precision, current_parameter_names, {database_vendor}}; const auto database_entries = std::vector{database_entry}; const auto database = Database(device_cpp, kernel_name, precision, database_entries); // Removes the old database entry and stores the new one in the cache DatabaseCache::Instance().Remove(DatabaseKey{platform_id, device, precision, kernel_name}); DatabaseCache::Instance().Store(DatabaseKey{platform_id, device, precision, kernel_name}, Database(database)); } catch (...) { return DispatchException(); } return StatusCode::kSuccess; } // ================================================================================================= } // namespace clblast