diff options
Diffstat (limited to 'src/routine.cpp')
-rw-r--r-- | src/routine.cpp | 63 |
1 files changed, 48 insertions, 15 deletions
diff --git a/src/routine.cpp b/src/routine.cpp index 4fe04a60..b5823bc9 100644 --- a/src/routine.cpp +++ b/src/routine.cpp @@ -21,36 +21,66 @@ namespace clblast { // ================================================================================================= +// For each kernel this map contains a list of routines it is used in +const std::vector<std::string> Routine::routines_axpy = {"AXPY", "COPY", "SCAL", "SWAP"}; +const std::vector<std::string> Routine::routines_dot = {"AMAX", "ASUM", "DOT", "DOTC", "DOTU", "MAX", "MIN", "NRM2", "SUM"}; +const std::vector<std::string> Routine::routines_ger = {"GER", "GERC", "GERU", "HER", "HER2", "HPR", "HPR2", "SPR", "SPR2", "SYR", "SYR2"}; +const std::vector<std::string> Routine::routines_gemv = {"GBMV", "GEMV", "HBMV", "HEMV", "HPMV", "SBMV", "SPMV", "SYMV", "TMBV", "TPMV", "TRMV", "TRSV"}; +const std::vector<std::string> Routine::routines_gemm = {"GEMM", "HEMM", "SYMM", "TRMM"}; +const std::vector<std::string> Routine::routines_gemm_syrk = {"GEMM", "HEMM", "HER2K", "HERK", "SYMM", "SYR2K", "SYRK", "TRMM", "TRSM"}; +const std::vector<std::string> Routine::routines_trsm = {"TRSM"}; +const std::unordered_map<std::string, const std::vector<std::string>> Routine::routines_by_kernel = { + {"Xaxpy", routines_axpy}, + {"Xdot", routines_dot}, + {"Xgemv", routines_gemv}, + {"XgemvFast", routines_gemv}, + {"XgemvFastRot", routines_gemv}, + {"Xtrsv", routines_gemv}, + {"Xger", routines_ger}, + {"Copy", routines_gemm_syrk}, + {"Pad", routines_gemm_syrk}, + {"Transpose", routines_gemm_syrk}, + {"Padtranspose", routines_gemm_syrk}, + {"Xgemm", routines_gemm_syrk}, + {"XgemmDirect", routines_gemm}, + {"KernelSelection", routines_gemm}, + {"Invert", routines_trsm}, +}; +// ================================================================================================= + // The constructor does all heavy work, errors are returned as exceptions Routine::Routine(Queue &queue, EventPointer event, const std::string &name, - const std::vector<std::string> &routines, const Precision precision, + const std::vector<std::string> &kernel_names, const Precision precision, const std::vector<const Database::DatabaseEntry*> &userDatabase, std::initializer_list<const char *> source): precision_(precision), routine_name_(name), + kernel_names_(kernel_names), queue_(queue), event_(event), context_(queue_.GetContext()), device_(queue_.GetDevice()), - device_name_(device_.Name()) { + device_name_(device_.Name()), + db_(kernel_names) { - InitDatabase(routines, userDatabase); + InitDatabase(userDatabase); InitProgram(source); } -void Routine::InitDatabase(const std::vector<std::string> &routines, - const std::vector<const Database::DatabaseEntry*> &userDatabase) { +void Routine::InitDatabase(const std::vector<const Database::DatabaseEntry*> &userDatabase) { + for (const auto &kernel_name : kernel_names_) { - // Queries the cache to see whether or not the kernel parameter database is already there - bool has_db; - db_ = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision_, device_name_, routines }, - &has_db); - if (has_db) { return; } + // Queries the cache to see whether or not the kernel parameter database is already there + bool has_db; + db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision_, device_name_, kernel_name }, + &has_db); + if (has_db) { continue; } - // Builds the parameter database for this device and routine set and stores it in the cache - db_ = Database(device_, routines, precision_, userDatabase); - DatabaseCache::Instance().Store(DatabaseKey{ precision_, device_name_, routines }, - Database{ db_ }); + // Builds the parameter database for this device and routine set and stores it in the cache + db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase); + DatabaseCache::Instance().Store(DatabaseKey{ precision_, device_name_, kernel_name }, + Database{ db_(kernel_name) }); + } } void Routine::InitProgram(std::initializer_list<const char *> source) { @@ -96,7 +126,10 @@ void Routine::InitProgram(std::initializer_list<const char *> source) { } // Collects the parameters for this device in the form of defines, and adds the precision - auto source_string = db_.GetDefines(); + auto source_string = std::string{""}; + for (const auto &kernel_name : kernel_names_) { + source_string += db_(kernel_name).GetDefines(); + } source_string += "#define PRECISION "+ToString(static_cast<int>(precision_))+"\n"; // Adds the name of the routine as a define |