diff options
-rw-r--r-- | include/clblast.h | 8 | ||||
-rwxr-xr-x | scripts/generator/generator.py | 4 | ||||
-rw-r--r-- | src/clblast.cpp | 32 | ||||
-rw-r--r-- | src/routine.cpp | 25 | ||||
-rw-r--r-- | src/routine.hpp | 10 |
5 files changed, 77 insertions, 2 deletions
diff --git a/include/clblast.h b/include/clblast.h index 7b2021d8..e7b53d65 100644 --- a/include/clblast.h +++ b/include/clblast.h @@ -17,6 +17,8 @@ #define CLBLAST_CLBLAST_H_ #include <cstdlib> // For size_t +#include <string> // For OverrideParameters function +#include <unordered_map> // For OverrideParameters function // Includes the normal OpenCL C header #if defined(__APPLE__) || defined(__MACOSX) @@ -617,6 +619,12 @@ StatusCode PUBLIC_API FillCache(const cl_device_id device); // ================================================================================================= +StatusCode PUBLIC_API OverrideParameters(const cl_device_id device, const std::string &kernel_name, + const Precision precision, + const std::unordered_map<std::string,size_t> ¶meters); + +// ================================================================================================= + } // namespace clblast // CLBLAST_CLBLAST_H_ diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index 6591cbf7..aaf1b121 100755 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -41,8 +41,8 @@ FILES = [ "/include/clblast_netlib_c.h", "/src/clblast_netlib_c.cpp", ] -HEADER_LINES = [117, 73, 118, 22, 29, 41, 65, 32] -FOOTER_LINES = [17, 95, 19, 18, 6, 6, 9, 2] +HEADER_LINES = [119, 73, 118, 22, 29, 41, 65, 32] +FOOTER_LINES = [23, 128, 19, 18, 6, 6, 9, 2] # Different possibilities for requirements ald_m = "The value of `a_ld` must be at least `m`." diff --git a/src/clblast.cpp b/src/clblast.cpp index 35f3f552..885b849e 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2254,4 +2254,36 @@ StatusCode FillCache(const cl_device_id device) { } // ================================================================================================= + +StatusCode OverrideParameters(const cl_device_id device, const std::string &kernel_name, + const Precision precision, + const std::unordered_map<std::string,size_t> ¶meters) { + try { + + // Retrieves the device name + const auto device_cpp = Device(device); + const auto device_name = device_cpp.Name(); + + // Clears the existing program & binary cache for routines with the target kernel + const auto routine_names = Routine::routines_by_kernel.at(kernel_name); + for (const auto &routine_name : routine_names) { + ProgramCache::Instance().RemoveBySubset<1, 2>(ProgramKey{nullptr, precision, routine_name}); + BinaryCache::Instance().Remove(BinaryKey{precision, routine_name, device_name}); + } + + // Creates a small custom database based on the provided parameters + const auto database_device = Database::DatabaseDevice{"default", parameters}; + const auto database_vendor = Database::DatabaseVendor{database::kDeviceTypeAll, "default", {database_device}}; + const auto database_entry = Database::DatabaseEntry{kernel_name, precision, {database_vendor}}; + const auto database = Database(device_cpp, kernel_name, precision, {&database_entry}); + + // Removes the old database entry and stores the new one in the cache + DatabaseCache::Instance().Remove(DatabaseKey{ precision, device_name, kernel_name }); + DatabaseCache::Instance().Store(DatabaseKey{ precision, device_name, kernel_name }, Database(database)); + + } catch (...) { return DispatchException(); } + return StatusCode::kSuccess; +} + +// ================================================================================================= } // namespace clblast diff --git a/src/routine.cpp b/src/routine.cpp index 854c7046..3cd045c8 100644 --- a/src/routine.cpp +++ b/src/routine.cpp @@ -21,6 +21,31 @@ namespace clblast { // ================================================================================================= +// For each kernel this map contains a list of routines it is used in +const std::vector<std::string> Routine::routines_axpy = {"AXPY", "COPY", "SCAL", "SWAP"}; +const std::vector<std::string> Routine::routines_dot = {"AMAX", "ASUM", "DOT", "DOTC", "DOTU", "MAX", "MIN", "NRM2", "SUM"}; +const std::vector<std::string> Routine::routines_ger = {"GER", "GERC", "GERU", "HER", "HER2", "HPR", "HPR2", "SPR", "SPR2", "SYR", "SYR2"}; +const std::vector<std::string> Routine::routines_gemv = {"GBMV", "GEMV", "HBMV", "HEMV", "HPMV", "SBMV", "SPMV", "SYMV", "TMBV", "TPMV", "TRMV"}; +const std::vector<std::string> Routine::routines_gemm = {"GEMM", "HEMM", "SYMM", "TRMM"}; +const std::vector<std::string> Routine::routines_gemm_syrk = {"GEMM", "HEMM", "HER2K", "HERK", "SYMM", "SYR2K", "SYRK", "TRMM"}; +const std::unordered_map<std::string, const std::vector<std::string>> Routine::routines_by_kernel = { + {"Xaxpy", routines_axpy}, + {"Xdot", routines_dot}, + {"Xgemv", routines_gemv}, + {"XgemvFast", routines_gemv}, + {"XgemvFastRot", routines_gemv}, + {"Xgemv", {}}, + {"Xger", routines_ger}, + {"Copy", routines_gemm_syrk}, + {"Pad", routines_gemm_syrk}, + {"Transpose", routines_gemm_syrk}, + {"Padtranspose", routines_gemm_syrk}, + {"Xgemm", routines_gemm_syrk}, + {"XgemmDirect", routines_gemm}, + {"KernelSelection", routines_gemm}, +}; +// ================================================================================================= + // The constructor does all heavy work, errors are returned as exceptions Routine::Routine(Queue &queue, EventPointer event, const std::string &name, const std::vector<std::string> &kernel_names, const Precision precision, diff --git a/src/routine.hpp b/src/routine.hpp index ba8b9f60..622a1c0d 100644 --- a/src/routine.hpp +++ b/src/routine.hpp @@ -18,6 +18,7 @@ #include <string> #include <vector> +#include <unordered_map> #include "utilities/utilities.hpp" #include "cache.hpp" @@ -42,6 +43,15 @@ class Routine { const std::vector<const Database::DatabaseEntry*> &userDatabase, std::initializer_list<const char *> source); + // List of kernel-routine look-ups + static const std::vector<std::string> routines_axpy; + static const std::vector<std::string> routines_dot; + static const std::vector<std::string> routines_ger; + static const std::vector<std::string> routines_gemv; + static const std::vector<std::string> routines_gemm; + static const std::vector<std::string> routines_gemm_syrk; + static const std::unordered_map<std::string, const std::vector<std::string>> routines_by_kernel; + private: // Initializes program_, fetching cached program or building one |