diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-10-08 21:52:02 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-10-08 21:52:02 +0200 |
commit | df3c9f4a8ab9e82ccc4add15b04da5c1b6172b72 (patch) | |
tree | 834d7f481d0370ce5e467e940396ee7c3738eb3c | |
parent | 2bb8402ec15a672eaa26595247aa09f7d88fecdb (diff) |
Moved non-routine-specific API functions and includes to separate files
-rw-r--r-- | CMakeLists.txt | 2 | ||||
-rwxr-xr-x | scripts/generator/generator.py | 4 | ||||
-rw-r--r-- | src/api_common.cpp | 169 | ||||
-rw-r--r-- | src/clblast.cpp | 207 | ||||
-rw-r--r-- | src/routines/routines.hpp | 76 |
5 files changed, 250 insertions, 208 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 05e7393b..52accbd4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -182,6 +182,7 @@ set(SOURCES src/routines/common.cpp src/utilities/clblast_exceptions.cpp src/utilities/utilities.cpp + src/api_common.cpp src/cache.cpp src/clblast.cpp src/clblast_c.cpp @@ -201,6 +202,7 @@ set(HEADERS # such that they can be discovered by IDEs such as CLion and Visual src/routines/level1/xmin.hpp src/routines/level1/xsum.hpp src/routines/common.hpp + src/routines/routines.hpp src/utilities/buffer_test.hpp src/utilities/clblast_exceptions.hpp src/utilities/device_mapping.hpp diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index df0eaca0..0d34d7fe 100755 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -42,8 +42,8 @@ FILES = [ "/include/clblast_netlib_c.h", "/src/clblast_netlib_c.cpp", ] -HEADER_LINES = [122, 79, 126, 24, 29, 41, 29, 65, 32] -FOOTER_LINES = [25, 147, 27, 38, 6, 6, 6, 9, 2] +HEADER_LINES = [122, 21, 126, 24, 29, 41, 29, 65, 32] +FOOTER_LINES = [25, 3, 27, 38, 6, 6, 6, 9, 2] HEADER_LINES_DOC = 0 FOOTER_LINES_DOC = 63 diff --git a/src/api_common.cpp b/src/api_common.cpp new file mode 100644 index 00000000..aa7e2b0f --- /dev/null +++ b/src/api_common.cpp @@ -0,0 +1,169 @@ +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren <www.cedricnugteren.nl> +// +// This file implements the common (non-OpenCL-specific) functions of the CLBlast API. +// +// ================================================================================================= + +#include <string> + +#include "cache.hpp" +#include "routines/routines.hpp" +#include "clblast.h" + +namespace clblast { +// ================================================================================================= + +// Clears the cache of stored binaries +StatusCode ClearCache() { + try { + ProgramCache::Instance().Invalidate(); + BinaryCache::Instance().Invalidate(); + } catch (...) { return DispatchException(); } + return StatusCode::kSuccess; +} + +template <typename Real, typename Complex> +void FillCacheForPrecision(Queue &queue) { + try { + + // Runs all the level 1 set-up functions + Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr); + Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr); + Xscal<Real>(queue, nullptr); Xscal<Complex>(queue, nullptr); + Xcopy<Real>(queue, nullptr); Xcopy<Complex>(queue, nullptr); + Xaxpy<Real>(queue, nullptr); Xaxpy<Complex>(queue, nullptr); + Xdot<Real>(queue, nullptr); + Xdotu<Complex>(queue, nullptr); + Xdotc<Complex>(queue, nullptr); + Xnrm2<Real>(queue, nullptr); Xnrm2<Complex>(queue, nullptr); + Xasum<Real>(queue, nullptr); Xasum<Complex>(queue, nullptr); + Xsum<Real>(queue, nullptr); Xsum<Complex>(queue, nullptr); + Xamax<Real>(queue, nullptr); Xamax<Complex>(queue, nullptr); + Xmax<Real>(queue, nullptr); Xmax<Complex>(queue, nullptr); + Xmin<Real>(queue, nullptr); Xmin<Complex>(queue, nullptr); + + // Runs all the level 2 set-up functions + Xgemv<Real>(queue, nullptr); Xgemv<Complex>(queue, nullptr); + Xgbmv<Real>(queue, nullptr); Xgbmv<Complex>(queue, nullptr); + Xhemv<Complex>(queue, nullptr); + Xhbmv<Complex>(queue, nullptr); + Xhpmv<Complex>(queue, nullptr); + Xsymv<Real>(queue, nullptr); + Xsbmv<Real>(queue, nullptr); + Xspmv<Real>(queue, nullptr); + Xtrmv<Real>(queue, nullptr); Xtrmv<Complex>(queue, nullptr); + Xtbmv<Real>(queue, nullptr); Xtbmv<Complex>(queue, nullptr); + Xtpmv<Real>(queue, nullptr); Xtpmv<Complex>(queue, nullptr); + Xger<Real>(queue, nullptr); + Xgeru<Complex>(queue, nullptr); + Xgerc<Complex>(queue, nullptr); + Xher<Complex,Real>(queue, nullptr); + Xhpr<Complex,Real>(queue, nullptr); + Xher2<Complex>(queue, nullptr); + Xhpr2<Complex>(queue, nullptr); + Xsyr<Real>(queue, nullptr); + Xspr<Real>(queue, nullptr); + Xsyr2<Real>(queue, nullptr); + Xspr2<Real>(queue, nullptr); + + // Runs all the level 3 set-up functions + Xgemm<Real>(queue, nullptr); Xgemm<Complex>(queue, nullptr); + Xsymm<Real>(queue, nullptr); Xsymm<Complex>(queue, nullptr); + Xhemm<Complex>(queue, nullptr); + Xsyrk<Real>(queue, nullptr); Xsyrk<Complex>(queue, nullptr); + Xherk<Complex,Real>(queue, nullptr); + Xsyr2k<Real>(queue, nullptr); Xsyr2k<Complex>(queue, nullptr); + Xher2k<Complex,Real>(queue, nullptr); + Xtrmm<Real>(queue, nullptr); Xtrmm<Complex>(queue, nullptr); + + // Runs all the non-BLAS set-up functions + Xomatcopy<Real>(queue, nullptr); Xomatcopy<Complex>(queue, nullptr); + + } catch(const RuntimeErrorCode &e) { + if (e.status() != StatusCode::kNoDoublePrecision && + e.status() != StatusCode::kNoHalfPrecision) { + throw; + } + } +} + +// Fills the cache with all binaries for a specific device +// TODO: Add half-precision FP16 set-up calls +StatusCode FillCache(const RawDeviceID device) { + try { + + // Creates a sample context and queue to match the normal routine calling conventions + auto device_cpp = Device(device); + auto context = Context(device_cpp); + auto queue = Queue(context, device_cpp); + + FillCacheForPrecision<float, float2>(queue); + FillCacheForPrecision<double, double2>(queue); + + } catch (...) { return DispatchException(); } + return StatusCode::kSuccess; +} + +// ================================================================================================= + +// Overrides the tuning parameters for this device-precision-kernel combination +StatusCode OverrideParameters(const RawDeviceID device, const std::string &kernel_name, + const Precision precision, + const std::unordered_map<std::string,size_t> ¶meters) { + try { + + // Retrieves the device name + const auto device_cpp = Device(device); + const auto platform_id = device_cpp.PlatformID(); + const auto device_name = GetDeviceName(device_cpp); + + // Retrieves the current database values to verify whether the new ones are complete + auto in_cache = false; + auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device, precision, kernel_name}, &in_cache); + if (!in_cache) { + log_debug("Searching database for kernel '" + kernel_name + "'"); + current_database = Database(device_cpp, kernel_name, precision, {}); + } + + // Verifies the parameters size + const auto current_parameter_names = current_database.GetParameterNames(); + if (current_parameter_names.size() != parameters.size()) { + return StatusCode::kMissingOverrideParameter; + } + + // Retrieves the names and values separately and in the same order as the existing database + auto parameter_values = database::Params{0}; + auto i = size_t{0}; + for (const auto ¤t_param : current_parameter_names) { + if (parameters.find(current_param) == parameters.end()) { + return StatusCode::kMissingOverrideParameter; + } + const auto parameter_value = parameters.at(current_param); + parameter_values[i] = parameter_value; + ++i; + } + + // Creates a small custom database based on the provided parameters + const auto database_device = database::DatabaseDevice{database::kDeviceNameDefault, parameter_values}; + const auto database_architecture = database::DatabaseArchitecture{"default", {database_device}}; + const auto database_vendor = database::DatabaseVendor{database::kDeviceTypeAll, "default", {database_architecture}}; + const auto database_entry = database::DatabaseEntry{kernel_name, precision, current_parameter_names, {database_vendor}}; + const auto database_entries = std::vector<database::DatabaseEntry>{database_entry}; + const auto database = Database(device_cpp, kernel_name, precision, database_entries); + + // Removes the old database entry and stores the new one in the cache + DatabaseCache::Instance().Remove(DatabaseKey{platform_id, device, precision, kernel_name}); + DatabaseCache::Instance().Store(DatabaseKey{platform_id, device, precision, kernel_name}, Database(database)); + + } catch (...) { return DispatchException(); } + return StatusCode::kSuccess; +} + +// ================================================================================================= +} // namespace clblast diff --git a/src/clblast.cpp b/src/clblast.cpp index 9f865a23..7d2c2cef 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -15,67 +15,9 @@ #include <string> -#include "cache.hpp" +#include "routines/routines.hpp" #include "clblast.h" -// BLAS level-1 includes -#include "routines/level1/xswap.hpp" -#include "routines/level1/xscal.hpp" -#include "routines/level1/xcopy.hpp" -#include "routines/level1/xaxpy.hpp" -#include "routines/level1/xdot.hpp" -#include "routines/level1/xdotu.hpp" -#include "routines/level1/xdotc.hpp" -#include "routines/level1/xnrm2.hpp" -#include "routines/level1/xasum.hpp" -#include "routines/level1/xsum.hpp" // non-BLAS routine -#include "routines/level1/xamax.hpp" -#include "routines/level1/xamin.hpp" // non-BLAS routine -#include "routines/level1/xmax.hpp" // non-BLAS routine -#include "routines/level1/xmin.hpp" // non-BLAS routine - -// BLAS level-2 includes -#include "routines/level2/xgemv.hpp" -#include "routines/level2/xgbmv.hpp" -#include "routines/level2/xhemv.hpp" -#include "routines/level2/xhbmv.hpp" -#include "routines/level2/xhpmv.hpp" -#include "routines/level2/xsymv.hpp" -#include "routines/level2/xsbmv.hpp" -#include "routines/level2/xspmv.hpp" -#include "routines/level2/xtrmv.hpp" -#include "routines/level2/xtbmv.hpp" -#include "routines/level2/xtpmv.hpp" -#include "routines/level2/xtrsv.hpp" -#include "routines/level2/xger.hpp" -#include "routines/level2/xgeru.hpp" -#include "routines/level2/xgerc.hpp" -#include "routines/level2/xher.hpp" -#include "routines/level2/xhpr.hpp" -#include "routines/level2/xher2.hpp" -#include "routines/level2/xhpr2.hpp" -#include "routines/level2/xsyr.hpp" -#include "routines/level2/xspr.hpp" -#include "routines/level2/xsyr2.hpp" -#include "routines/level2/xspr2.hpp" - -// BLAS level-3 includes -#include "routines/level3/xgemm.hpp" -#include "routines/level3/xsymm.hpp" -#include "routines/level3/xhemm.hpp" -#include "routines/level3/xsyrk.hpp" -#include "routines/level3/xherk.hpp" -#include "routines/level3/xsyr2k.hpp" -#include "routines/level3/xher2k.hpp" -#include "routines/level3/xtrmm.hpp" -#include "routines/level3/xtrsm.hpp" - -// Level-x includes (non-BLAS) -#include "routines/levelx/xomatcopy.hpp" -#include "routines/levelx/xim2col.hpp" -#include "routines/levelx/xaxpybatched.hpp" -#include "routines/levelx/xgemmbatched.hpp" - namespace clblast { // ================================================================================================= @@ -2389,153 +2331,6 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose, cl_mem, const size_t*, const size_t, const size_t, cl_command_queue*, cl_event*); -// ================================================================================================= - -// Clears the cache of stored binaries -StatusCode ClearCache() { - try { - ProgramCache::Instance().Invalidate(); - BinaryCache::Instance().Invalidate(); - } catch (...) { return DispatchException(); } - return StatusCode::kSuccess; -} - -template <typename Real, typename Complex> -void FillCacheForPrecision(Queue &queue) { - try { - - // Runs all the level 1 set-up functions - Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr); - Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr); - Xscal<Real>(queue, nullptr); Xscal<Complex>(queue, nullptr); - Xcopy<Real>(queue, nullptr); Xcopy<Complex>(queue, nullptr); - Xaxpy<Real>(queue, nullptr); Xaxpy<Complex>(queue, nullptr); - Xdot<Real>(queue, nullptr); - Xdotu<Complex>(queue, nullptr); - Xdotc<Complex>(queue, nullptr); - Xnrm2<Real>(queue, nullptr); Xnrm2<Complex>(queue, nullptr); - Xasum<Real>(queue, nullptr); Xasum<Complex>(queue, nullptr); - Xsum<Real>(queue, nullptr); Xsum<Complex>(queue, nullptr); - Xamax<Real>(queue, nullptr); Xamax<Complex>(queue, nullptr); - Xmax<Real>(queue, nullptr); Xmax<Complex>(queue, nullptr); - Xmin<Real>(queue, nullptr); Xmin<Complex>(queue, nullptr); - - // Runs all the level 2 set-up functions - Xgemv<Real>(queue, nullptr); Xgemv<Complex>(queue, nullptr); - Xgbmv<Real>(queue, nullptr); Xgbmv<Complex>(queue, nullptr); - Xhemv<Complex>(queue, nullptr); - Xhbmv<Complex>(queue, nullptr); - Xhpmv<Complex>(queue, nullptr); - Xsymv<Real>(queue, nullptr); - Xsbmv<Real>(queue, nullptr); - Xspmv<Real>(queue, nullptr); - Xtrmv<Real>(queue, nullptr); Xtrmv<Complex>(queue, nullptr); - Xtbmv<Real>(queue, nullptr); Xtbmv<Complex>(queue, nullptr); - Xtpmv<Real>(queue, nullptr); Xtpmv<Complex>(queue, nullptr); - Xger<Real>(queue, nullptr); - Xgeru<Complex>(queue, nullptr); - Xgerc<Complex>(queue, nullptr); - Xher<Complex,Real>(queue, nullptr); - Xhpr<Complex,Real>(queue, nullptr); - Xher2<Complex>(queue, nullptr); - Xhpr2<Complex>(queue, nullptr); - Xsyr<Real>(queue, nullptr); - Xspr<Real>(queue, nullptr); - Xsyr2<Real>(queue, nullptr); - Xspr2<Real>(queue, nullptr); - - // Runs all the level 3 set-up functions - Xgemm<Real>(queue, nullptr); Xgemm<Complex>(queue, nullptr); - Xsymm<Real>(queue, nullptr); Xsymm<Complex>(queue, nullptr); - Xhemm<Complex>(queue, nullptr); - Xsyrk<Real>(queue, nullptr); Xsyrk<Complex>(queue, nullptr); - Xherk<Complex,Real>(queue, nullptr); - Xsyr2k<Real>(queue, nullptr); Xsyr2k<Complex>(queue, nullptr); - Xher2k<Complex,Real>(queue, nullptr); - Xtrmm<Real>(queue, nullptr); Xtrmm<Complex>(queue, nullptr); - - // Runs all the non-BLAS set-up functions - Xomatcopy<Real>(queue, nullptr); Xomatcopy<Complex>(queue, nullptr); - - } catch(const RuntimeErrorCode &e) { - if (e.status() != StatusCode::kNoDoublePrecision && - e.status() != StatusCode::kNoHalfPrecision) { - throw; - } - } -} - -// Fills the cache with all binaries for a specific device -// TODO: Add half-precision FP16 set-up calls -StatusCode FillCache(const cl_device_id device) { - try { - - // Creates a sample context and queue to match the normal routine calling conventions - auto device_cpp = Device(device); - auto context = Context(device_cpp); - auto queue = Queue(context, device_cpp); - - FillCacheForPrecision<float, float2>(queue); - FillCacheForPrecision<double, double2>(queue); - - } catch (...) { return DispatchException(); } - return StatusCode::kSuccess; -} - -// ================================================================================================= - -// Overrides the tuning parameters for this device-precision-kernel combination -StatusCode OverrideParameters(const cl_device_id device, const std::string &kernel_name, - const Precision precision, - const std::unordered_map<std::string,size_t> ¶meters) { - try { - - // Retrieves the device name - const auto device_cpp = Device(device); - const auto platform_id = device_cpp.PlatformID(); - const auto device_name = GetDeviceName(device_cpp); - - // Retrieves the current database values to verify whether the new ones are complete - auto in_cache = false; - auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device, precision, kernel_name}, &in_cache); - if (!in_cache) { - log_debug("Searching database for kernel '" + kernel_name + "'"); - current_database = Database(device_cpp, kernel_name, precision, {}); - } - - // Verifies the parameters size - const auto current_parameter_names = current_database.GetParameterNames(); - if (current_parameter_names.size() != parameters.size()) { - return StatusCode::kMissingOverrideParameter; - } - - // Retrieves the names and values separately and in the same order as the existing database - auto parameter_values = database::Params{0}; - auto i = size_t{0}; - for (const auto ¤t_param : current_parameter_names) { - if (parameters.find(current_param) == parameters.end()) { - return StatusCode::kMissingOverrideParameter; - } - const auto parameter_value = parameters.at(current_param); - parameter_values[i] = parameter_value; - ++i; - } - - // Creates a small custom database based on the provided parameters - const auto database_device = database::DatabaseDevice{database::kDeviceNameDefault, parameter_values}; - const auto database_architecture = database::DatabaseArchitecture{"default", {database_device}}; - const auto database_vendor = database::DatabaseVendor{database::kDeviceTypeAll, "default", {database_architecture}}; - const auto database_entry = database::DatabaseEntry{kernel_name, precision, current_parameter_names, {database_vendor}}; - const auto database_entries = std::vector<database::DatabaseEntry>{database_entry}; - const auto database = Database(device_cpp, kernel_name, precision, database_entries); - - // Removes the old database entry and stores the new one in the cache - DatabaseCache::Instance().Remove(DatabaseKey{platform_id, device, precision, kernel_name}); - DatabaseCache::Instance().Store(DatabaseKey{platform_id, device, precision, kernel_name}, Database(database)); - - } catch (...) { return DispatchException(); } - return StatusCode::kSuccess; -} // ================================================================================================= } // namespace clblast diff --git a/src/routines/routines.hpp b/src/routines/routines.hpp new file mode 100644 index 00000000..9e7768b9 --- /dev/null +++ b/src/routines/routines.hpp @@ -0,0 +1,76 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren <www.cedricnugteren.nl> +// +// This file contains all the includes of all the routines in CLBlast. +// +// ================================================================================================= + +#ifndef CLBLAST_ROUTINES_ROUTINES_H_ +#define CLBLAST_ROUTINES_ROUTINES_H_ + +// BLAS level-1 includes +#include "routines/level1/xswap.hpp" +#include "routines/level1/xscal.hpp" +#include "routines/level1/xcopy.hpp" +#include "routines/level1/xaxpy.hpp" +#include "routines/level1/xdot.hpp" +#include "routines/level1/xdotu.hpp" +#include "routines/level1/xdotc.hpp" +#include "routines/level1/xnrm2.hpp" +#include "routines/level1/xasum.hpp" +#include "routines/level1/xsum.hpp" // non-BLAS routine +#include "routines/level1/xamax.hpp" +#include "routines/level1/xamin.hpp" // non-BLAS routine +#include "routines/level1/xmax.hpp" // non-BLAS routine +#include "routines/level1/xmin.hpp" // non-BLAS routine + +// BLAS level-2 includes +#include "routines/level2/xgemv.hpp" +#include "routines/level2/xgbmv.hpp" +#include "routines/level2/xhemv.hpp" +#include "routines/level2/xhbmv.hpp" +#include "routines/level2/xhpmv.hpp" +#include "routines/level2/xsymv.hpp" +#include "routines/level2/xsbmv.hpp" +#include "routines/level2/xspmv.hpp" +#include "routines/level2/xtrmv.hpp" +#include "routines/level2/xtbmv.hpp" +#include "routines/level2/xtpmv.hpp" +#include "routines/level2/xtrsv.hpp" +#include "routines/level2/xger.hpp" +#include "routines/level2/xgeru.hpp" +#include "routines/level2/xgerc.hpp" +#include "routines/level2/xher.hpp" +#include "routines/level2/xhpr.hpp" +#include "routines/level2/xher2.hpp" +#include "routines/level2/xhpr2.hpp" +#include "routines/level2/xsyr.hpp" +#include "routines/level2/xspr.hpp" +#include "routines/level2/xsyr2.hpp" +#include "routines/level2/xspr2.hpp" + +// BLAS level-3 includes +#include "routines/level3/xgemm.hpp" +#include "routines/level3/xsymm.hpp" +#include "routines/level3/xhemm.hpp" +#include "routines/level3/xsyrk.hpp" +#include "routines/level3/xherk.hpp" +#include "routines/level3/xsyr2k.hpp" +#include "routines/level3/xher2k.hpp" +#include "routines/level3/xtrmm.hpp" +#include "routines/level3/xtrsm.hpp" + +// Level-x includes (non-BLAS) +#include "routines/levelx/xomatcopy.hpp" +#include "routines/levelx/xim2col.hpp" +#include "routines/levelx/xaxpybatched.hpp" +#include "routines/levelx/xgemmbatched.hpp" + +// CLBLAST_ROUTINES_ROUTINES_H_ +#endif |