summaryrefslogtreecommitdiff
path: root/src/routine.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/routine.cpp')
-rw-r--r--src/routine.cpp63
1 files changed, 48 insertions, 15 deletions
diff --git a/src/routine.cpp b/src/routine.cpp
index 4fe04a60..b5823bc9 100644
--- a/src/routine.cpp
+++ b/src/routine.cpp
@@ -21,36 +21,66 @@
namespace clblast {
// =================================================================================================
+// Routine-name lists shared between kernels, followed by the map that gives, for each kernel, the list of routines it is used in
+const std::vector<std::string> Routine::routines_axpy = {"AXPY", "COPY", "SCAL", "SWAP"};
+const std::vector<std::string> Routine::routines_dot = {"AMAX", "ASUM", "DOT", "DOTC", "DOTU", "MAX", "MIN", "NRM2", "SUM"};
+const std::vector<std::string> Routine::routines_ger = {"GER", "GERC", "GERU", "HER", "HER2", "HPR", "HPR2", "SPR", "SPR2", "SYR", "SYR2"};
+const std::vector<std::string> Routine::routines_gemv = {"GBMV", "GEMV", "HBMV", "HEMV", "HPMV", "SBMV", "SPMV", "SYMV", "TBMV", "TPMV", "TRMV", "TRSV"};
+const std::vector<std::string> Routine::routines_gemm = {"GEMM", "HEMM", "SYMM", "TRMM"};
+const std::vector<std::string> Routine::routines_gemm_syrk = {"GEMM", "HEMM", "HER2K", "HERK", "SYMM", "SYR2K", "SYRK", "TRMM", "TRSM"};
+const std::vector<std::string> Routine::routines_trsm = {"TRSM"};
+const std::unordered_map<std::string, const std::vector<std::string>> Routine::routines_by_kernel = {
+ {"Xaxpy", routines_axpy},
+ {"Xdot", routines_dot},
+ {"Xgemv", routines_gemv},
+ {"XgemvFast", routines_gemv},
+ {"XgemvFastRot", routines_gemv},
+ {"Xtrsv", routines_gemv},
+ {"Xger", routines_ger},
+ {"Copy", routines_gemm_syrk},
+ {"Pad", routines_gemm_syrk},
+ {"Transpose", routines_gemm_syrk},
+ {"Padtranspose", routines_gemm_syrk},
+ {"Xgemm", routines_gemm_syrk},
+ {"XgemmDirect", routines_gemm},
+ {"KernelSelection", routines_gemm},
+ {"Invert", routines_trsm},
+};
+// =================================================================================================
+
// The constructor does all heavy work (it calls InitDatabase and InitProgram), errors are returned as exceptions
Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
- const std::vector<std::string> &routines, const Precision precision,
+ const std::vector<std::string> &kernel_names, const Precision precision,
const std::vector<const Database::DatabaseEntry*> &userDatabase,
std::initializer_list<const char *> source):
precision_(precision),
routine_name_(name),
+ kernel_names_(kernel_names),
queue_(queue),
event_(event),
context_(queue_.GetContext()),
device_(queue_.GetDevice()),
- device_name_(device_.Name()) {
+ device_name_(device_.Name()),
+ db_(kernel_names) {
- InitDatabase(routines, userDatabase);
+ InitDatabase(userDatabase);
InitProgram(source);
}
-void Routine::InitDatabase(const std::vector<std::string> &routines,
- const std::vector<const Database::DatabaseEntry*> &userDatabase) {
+void Routine::InitDatabase(const std::vector<const Database::DatabaseEntry*> &userDatabase) {
+ for (const auto &kernel_name : kernel_names_) {
- // Queries the cache to see whether or not the kernel parameter database is already there
- bool has_db;
- db_ = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision_, device_name_, routines },
- &has_db);
- if (has_db) { return; }
+ // Queries the cache to see whether or not the parameter database for this kernel is already there
+ bool has_db;
+ db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision_, device_name_, kernel_name },
+ &has_db);
+ if (has_db) { continue; }
- // Builds the parameter database for this device and routine set and stores it in the cache
- db_ = Database(device_, routines, precision_, userDatabase);
- DatabaseCache::Instance().Store(DatabaseKey{ precision_, device_name_, routines },
- Database{ db_ });
+ // Builds the parameter database for this device and kernel and stores it in the cache
+ db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase);
+ DatabaseCache::Instance().Store(DatabaseKey{ precision_, device_name_, kernel_name },
+ Database{ db_(kernel_name) });
+ }
}
void Routine::InitProgram(std::initializer_list<const char *> source) {
@@ -96,7 +126,10 @@ void Routine::InitProgram(std::initializer_list<const char *> source) {
}
// Collects the parameters for this device in the form of defines, and adds the precision
- auto source_string = db_.GetDefines();
+ auto source_string = std::string{""};
+ for (const auto &kernel_name : kernel_names_) {
+ source_string += db_(kernel_name).GetDefines();
+ }
source_string += "#define PRECISION "+ToString(static_cast<int>(precision_))+"\n";
// Adds the name of the routine as a define