summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/cache.cpp32
-rw-r--r--src/cache.hpp11
-rw-r--r--src/clblast.cpp43
-rw-r--r--src/clblast_c.cpp21
-rw-r--r--src/database/database.cpp32
-rw-r--r--src/database/database.hpp33
-rw-r--r--src/routine.cpp61
-rw-r--r--src/routine.hpp18
-rw-r--r--src/routines/common.hpp2
-rw-r--r--src/routines/level2/xgemv.cpp2
10 files changed, 215 insertions, 40 deletions
diff --git a/src/cache.cpp b/src/cache.cpp
index c5cc6a4d..4b74b0a1 100644
--- a/src/cache.cpp
+++ b/src/cache.cpp
@@ -65,6 +65,37 @@ void Cache<Key, Value>::Store(Key &&key, Value &&value) {
}
template <typename Key, typename Value>
+void Cache<Key, Value>::Remove(const Key &key) {
+ std::lock_guard<std::mutex> lock(cache_mutex_);
+#if __cplusplus >= 201402L
+ cache_.erase(key);
+#else
+ auto it = cache_.begin();
+ while (it != cache_.end()) {
+ if ((*it).first == key) {
+ it = cache_.erase(it);
+ }
+ else ++it;
+ }
+#endif
+}
+
+template <typename Key, typename Value>
+template <int I1, int I2>
+void Cache<Key, Value>::RemoveBySubset(const Key &key) {
+ std::lock_guard<std::mutex> lock(cache_mutex_);
+ auto it = cache_.begin();
+ while (it != cache_.end()) {
+ const auto current_key = (*it).first;
+ if ((std::get<I1>(key) == std::get<I1>(current_key)) &&
+ (std::get<I2>(key) == std::get<I2>(current_key))) {
+ it = cache_.erase(it);
+ }
+ else ++it;
+ }
+}
+
+template <typename Key, typename Value>
void Cache<Key, Value>::Invalidate() {
std::lock_guard<std::mutex> lock(cache_mutex_);
@@ -88,6 +119,7 @@ template std::string BinaryCache::Get(const BinaryKeyRef &, bool *) const;
template class Cache<ProgramKey, Program>;
template Program ProgramCache::Get(const ProgramKeyRef &, bool *) const;
+template void ProgramCache::RemoveBySubset<1, 2>(const ProgramKey &); // precision and routine name
// =================================================================================================
diff --git a/src/cache.hpp b/src/cache.hpp
index c3675f07..694de839 100644
--- a/src/cache.hpp
+++ b/src/cache.hpp
@@ -42,6 +42,10 @@ public:
void Store(Key &&key, Value &&value);
void Invalidate();
+ // Removes all entries with a given key
+ void Remove(const Key &key);
+ template <int I1, int I2> void RemoveBySubset(const Key &key); // currently supports 2 indices
+
static Cache<Key, Value> &Instance();
private:
@@ -72,7 +76,6 @@ typedef Cache<BinaryKey, std::string> BinaryCache;
extern template class Cache<BinaryKey, std::string>;
extern template std::string BinaryCache::Get(const BinaryKeyRef &, bool *) const;
-
// =================================================================================================
// The key struct for the cache of compiled OpenCL programs (context-dependent)
@@ -90,9 +93,9 @@ extern template Program ProgramCache::Get(const ProgramKeyRef &, bool *) const;
class Database;
// The key struct for the cache of database maps.
-// Order of fields: precision, device_name, routines (smaller fields first)
-typedef std::tuple<Precision, std::string, std::vector<std::string>> DatabaseKey;
-typedef std::tuple<const Precision &, const std::string &, const std::vector<std::string> &> DatabaseKeyRef;
+// Order of fields: precision, device_name, kernel_name (smaller fields first)
+typedef std::tuple<Precision, std::string, std::string> DatabaseKey;
+typedef std::tuple<const Precision &, const std::string &, const std::string &> DatabaseKeyRef;
typedef Cache<DatabaseKey, Database> DatabaseCache;
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 35f3f552..a8e4d084 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2254,4 +2254,47 @@ StatusCode FillCache(const cl_device_id device) {
}
// =================================================================================================
+
+// Overrides the tuning parameters for this device-precision-kernel combination
+StatusCode OverrideParameters(const cl_device_id device, const std::string &kernel_name,
+ const Precision precision,
+ const std::unordered_map<std::string,size_t> &parameters) {
+ try {
+
+ // Retrieves the device name
+ const auto device_cpp = Device(device);
+ const auto device_name = device_cpp.Name();
+
+ // Retrieves the current database values to verify whether the new ones are complete
+ auto in_cache = false;
+ const auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision, device_name, kernel_name }, &in_cache);
+ if (!in_cache) { return StatusCode::kInvalidOverrideKernel; }
+ for (const auto &current_param : current_database.GetParameterNames()) {
+ if (parameters.find(current_param) == parameters.end()) {
+ return StatusCode::kMissingOverrideParameter;
+ }
+ }
+
+ // Clears the existing program & binary cache for routines with the target kernel
+ const auto routine_names = Routine::routines_by_kernel.at(kernel_name);
+ for (const auto &routine_name : routine_names) {
+ ProgramCache::Instance().RemoveBySubset<1, 2>(ProgramKey{nullptr, precision, routine_name});
+ BinaryCache::Instance().Remove(BinaryKey{precision, routine_name, device_name});
+ }
+
+ // Creates a small custom database based on the provided parameters
+ const auto database_device = Database::DatabaseDevice{"default", parameters};
+ const auto database_vendor = Database::DatabaseVendor{database::kDeviceTypeAll, "default", {database_device}};
+ const auto database_entry = Database::DatabaseEntry{kernel_name, precision, {database_vendor}};
+ const auto database = Database(device_cpp, kernel_name, precision, {&database_entry});
+
+ // Removes the old database entry and stores the new one in the cache
+ DatabaseCache::Instance().Remove(DatabaseKey{ precision, device_name, kernel_name });
+ DatabaseCache::Instance().Store(DatabaseKey{ precision, device_name, kernel_name }, Database(database));
+
+ } catch (...) { return DispatchException(); }
+ return StatusCode::kSuccess;
+}
+
+// =================================================================================================
} // namespace clblast
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp
index e4f2b3ed..de431fa4 100644
--- a/src/clblast_c.cpp
+++ b/src/clblast_c.cpp
@@ -12,6 +12,7 @@
// =================================================================================================
#include <string>
+#include <unordered_map>
#include "utilities/utilities.hpp"
#include "clblast_c.h"
@@ -3484,3 +3485,23 @@ CLBlastStatusCode CLBlastFillCache(const cl_device_id device) {
}
// =================================================================================================
+
+// Overrides the tuning parameters for this device-precision-kernel combination
+CLBlastStatusCode PUBLIC_API CLBlastOverrideParameters(const cl_device_id device, const char* kernel_name,
+ const CLBlastPrecision precision, const size_t num_parameters,
+ const char** parameters_names, const size_t* parameters_values) {
+ try {
+ const auto kernel_name_cpp = std::string(kernel_name);
+ const auto precision_cpp = static_cast<clblast::Precision>(precision);
+ auto parameters = std::unordered_map<std::string, size_t>();
+ for (auto i = size_t{0}; i < num_parameters; ++i) {
+ const auto parameter_name = std::string(parameters_names[i]);
+ const auto parameter_value = parameters_values[i];
+ parameters[parameter_name] = parameter_value;
+ }
+ const auto status = clblast::OverrideParameters(device, kernel_name_cpp, precision_cpp, parameters);
+ return static_cast<CLBlastStatusCode>(status);
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
+// =================================================================================================
diff --git a/src/database/database.cpp b/src/database/database.cpp
index c1cb9d56..02d0b139 100644
--- a/src/database/database.cpp
+++ b/src/database/database.cpp
@@ -63,7 +63,7 @@ const std::unordered_map<std::string, std::string> Database::kVendorNames{
// Constructor, computing device properties and populating the parameter-vector from the database.
// This takes an optional overlay database in case of custom tuning or custom kernels.
-Database::Database(const Device &device, const std::vector<std::string> &kernels,
+Database::Database(const Device &device, const std::string &kernel_name,
const Precision precision, const std::vector<const DatabaseEntry*> &overlay):
parameters_(std::make_shared<Parameters>()) {
@@ -79,20 +79,17 @@ Database::Database(const Device &device, const std::vector<std::string> &kernels
}
}
- // Iterates over all kernels to include, and retrieves the parameters for each of them
- for (auto &kernel: kernels) {
- auto search_result = ParametersPtr{};
-
- for (auto &db: { database, overlay}) {
- search_result = Search(kernel, device_type, device_vendor, device_name, precision, db);
- if (search_result) {
- parameters_->insert(search_result->begin(), search_result->end());
- break;
- }
+ // Searches potentially multiple databases
+ auto search_result = ParametersPtr{};
+ for (auto &db: { overlay, database}) {
+ search_result = Search(kernel_name, device_type, device_vendor, device_name, precision, db);
+ if (search_result) {
+ parameters_->insert(search_result->begin(), search_result->end());
+ break;
}
-
- if (!search_result) { throw RuntimeErrorCode(StatusCode::kDatabaseError); }
}
+
+ if (!search_result) { throw RuntimeErrorCode(StatusCode::kDatabaseError); }
}
// =================================================================================================
@@ -106,6 +103,15 @@ std::string Database::GetDefines() const {
return defines;
}
+// Retrieves the names of all the parameters
+std::vector<std::string> Database::GetParameterNames() const {
+ auto parameter_names = std::vector<std::string>();
+ for (auto &parameter: *parameters_) {
+ parameter_names.push_back(parameter.first);
+ }
+ return parameter_names;
+}
+
// =================================================================================================
// Searches a particular database for the right kernel and precision
diff --git a/src/database/database.hpp b/src/database/database.hpp
index 87c12293..b34e0d8a 100644
--- a/src/database/database.hpp
+++ b/src/database/database.hpp
@@ -75,15 +75,19 @@ class Database {
Database() = default;
// The constructor with a user-provided database overlay (potentially an empty vector)
- explicit Database(const Device &device, const std::vector<std::string> &routines,
+ explicit Database(const Device &device, const std::string &kernel_name,
const Precision precision, const std::vector<const DatabaseEntry*> &overlay);
// Accessor of values by key
- size_t operator[](const std::string key) const { return parameters_->find(key)->second; }
+ size_t operator[](const std::string &key) const { return parameters_->find(key)->second; }
+ bool exists(const std::string &key) const { return (parameters_->count(key) == 1); }
// Obtain a list of OpenCL pre-processor defines based on the parameters
std::string GetDefines() const;
+ // Retrieves the names of all the parameters
+ std::vector<std::string> GetParameterNames() const;
+
private:
// Search method for a specified database, returning pointer (possibly a nullptr)
ParametersPtr Search(const std::string &this_kernel, const std::string &this_type,
@@ -96,6 +100,31 @@ class Database {
};
// =================================================================================================
+
+// Multiple databases together in a map
+class Databases {
+ public:
+
+ explicit Databases(const std::vector<std::string> &kernel_names): kernel_names_(kernel_names) { }
+
+ // Database accessor
+ Database& operator()(const std::string &kernel_name) { return databases_[kernel_name]; }
+
+ // Retrieves a parameter from the database
+ size_t operator[](const std::string &key) const {
+ for (const auto &kernel_name : kernel_names_) {
+ const auto &kernel_db = databases_.find(kernel_name)->second;
+ if (kernel_db.exists(key)) { return kernel_db[key]; }
+ }
+ throw RuntimeErrorCode(StatusCode::kDatabaseError);
+ }
+
+ private:
+ const std::vector<std::string> kernel_names_;
+ std::unordered_map<std::string, Database> databases_;
+};
+
+// =================================================================================================
} // namespace clblast
// CLBLAST_DATABASE_H_
diff --git a/src/routine.cpp b/src/routine.cpp
index 4fe04a60..3cd045c8 100644
--- a/src/routine.cpp
+++ b/src/routine.cpp
@@ -21,36 +21,64 @@
namespace clblast {
// =================================================================================================
+// For each kernel this map contains a list of routines it is used in
+const std::vector<std::string> Routine::routines_axpy = {"AXPY", "COPY", "SCAL", "SWAP"};
+const std::vector<std::string> Routine::routines_dot = {"AMAX", "ASUM", "DOT", "DOTC", "DOTU", "MAX", "MIN", "NRM2", "SUM"};
+const std::vector<std::string> Routine::routines_ger = {"GER", "GERC", "GERU", "HER", "HER2", "HPR", "HPR2", "SPR", "SPR2", "SYR", "SYR2"};
+const std::vector<std::string> Routine::routines_gemv = {"GBMV", "GEMV", "HBMV", "HEMV", "HPMV", "SBMV", "SPMV", "SYMV", "TMBV", "TPMV", "TRMV"};
+const std::vector<std::string> Routine::routines_gemm = {"GEMM", "HEMM", "SYMM", "TRMM"};
+const std::vector<std::string> Routine::routines_gemm_syrk = {"GEMM", "HEMM", "HER2K", "HERK", "SYMM", "SYR2K", "SYRK", "TRMM"};
+const std::unordered_map<std::string, const std::vector<std::string>> Routine::routines_by_kernel = {
+ {"Xaxpy", routines_axpy},
+ {"Xdot", routines_dot},
+ {"Xgemv", routines_gemv},
+ {"XgemvFast", routines_gemv},
+ {"XgemvFastRot", routines_gemv},
+ {"Xgemv", {}},
+ {"Xger", routines_ger},
+ {"Copy", routines_gemm_syrk},
+ {"Pad", routines_gemm_syrk},
+ {"Transpose", routines_gemm_syrk},
+ {"Padtranspose", routines_gemm_syrk},
+ {"Xgemm", routines_gemm_syrk},
+ {"XgemmDirect", routines_gemm},
+ {"KernelSelection", routines_gemm},
+};
+// =================================================================================================
+
// The constructor does all heavy work, errors are returned as exceptions
Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
- const std::vector<std::string> &routines, const Precision precision,
+ const std::vector<std::string> &kernel_names, const Precision precision,
const std::vector<const Database::DatabaseEntry*> &userDatabase,
std::initializer_list<const char *> source):
precision_(precision),
routine_name_(name),
+ kernel_names_(kernel_names),
queue_(queue),
event_(event),
context_(queue_.GetContext()),
device_(queue_.GetDevice()),
- device_name_(device_.Name()) {
+ device_name_(device_.Name()),
+ db_(kernel_names) {
- InitDatabase(routines, userDatabase);
+ InitDatabase(userDatabase);
InitProgram(source);
}
-void Routine::InitDatabase(const std::vector<std::string> &routines,
- const std::vector<const Database::DatabaseEntry*> &userDatabase) {
+void Routine::InitDatabase(const std::vector<const Database::DatabaseEntry*> &userDatabase) {
+ for (const auto &kernel_name : kernel_names_) {
- // Queries the cache to see whether or not the kernel parameter database is already there
- bool has_db;
- db_ = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision_, device_name_, routines },
- &has_db);
- if (has_db) { return; }
+ // Queries the cache to see whether or not the kernel parameter database is already there
+ bool has_db;
+ db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision_, device_name_, kernel_name },
+ &has_db);
+ if (has_db) { continue; }
- // Builds the parameter database for this device and routine set and stores it in the cache
- db_ = Database(device_, routines, precision_, userDatabase);
- DatabaseCache::Instance().Store(DatabaseKey{ precision_, device_name_, routines },
- Database{ db_ });
+ // Builds the parameter database for this device and routine set and stores it in the cache
+ db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase);
+ DatabaseCache::Instance().Store(DatabaseKey{ precision_, device_name_, kernel_name },
+ Database{ db_(kernel_name) });
+ }
}
void Routine::InitProgram(std::initializer_list<const char *> source) {
@@ -96,7 +124,10 @@ void Routine::InitProgram(std::initializer_list<const char *> source) {
}
// Collects the parameters for this device in the form of defines, and adds the precision
- auto source_string = db_.GetDefines();
+ auto source_string = std::string{""};
+ for (const auto &kernel_name : kernel_names_) {
+ source_string += db_(kernel_name).GetDefines();
+ }
source_string += "#define PRECISION "+ToString(static_cast<int>(precision_))+"\n";
// Adds the name of the routine as a define
diff --git a/src/routine.hpp b/src/routine.hpp
index f366e4d9..622a1c0d 100644
--- a/src/routine.hpp
+++ b/src/routine.hpp
@@ -18,6 +18,7 @@
#include <string>
#include <vector>
+#include <unordered_map>
#include "utilities/utilities.hpp"
#include "cache.hpp"
@@ -42,22 +43,31 @@ class Routine {
const std::vector<const Database::DatabaseEntry*> &userDatabase,
std::initializer_list<const char *> source);
+ // List of kernel-routine look-ups
+ static const std::vector<std::string> routines_axpy;
+ static const std::vector<std::string> routines_dot;
+ static const std::vector<std::string> routines_ger;
+ static const std::vector<std::string> routines_gemv;
+ static const std::vector<std::string> routines_gemm;
+ static const std::vector<std::string> routines_gemm_syrk;
+ static const std::unordered_map<std::string, const std::vector<std::string>> routines_by_kernel;
+
private:
// Initializes program_, fetching cached program or building one
void InitProgram(std::initializer_list<const char *> source);
// Initializes db_, fetching cached database or building one
- void InitDatabase(const std::vector<std::string> &routines,
- const std::vector<const Database::DatabaseEntry*> &userDatabase);
+ void InitDatabase(const std::vector<const Database::DatabaseEntry*> &userDatabase);
protected:
// Non-static variable for the precision
const Precision precision_;
- // The routine's name
+ // The routine's name and the corresponding kernels
const std::string routine_name_;
+ const std::vector<std::string> kernel_names_;
// The OpenCL objects, accessible only from derived classes
Queue queue_;
@@ -72,7 +82,7 @@ class Routine {
Program program_;
// Connection to the database for all the device-specific parameters
- Database db_;
+ Databases db_;
};
// =================================================================================================
diff --git a/src/routines/common.hpp b/src/routines/common.hpp
index 7c211c0d..d268d58b 100644
--- a/src/routines/common.hpp
+++ b/src/routines/common.hpp
@@ -37,7 +37,7 @@ void RunKernel(Kernel &kernel, Queue &queue, const Device &device,
// to write to symmetric and triangular matrices through optional arguments.
template <typename T>
void PadCopyTransposeMatrix(Queue &queue, const Device &device,
- const Database &db,
+ const Databases &db,
EventPointer event, const std::vector<Event> &waitForEvents,
const size_t src_one, const size_t src_two,
const size_t src_ld, const size_t src_offset,
diff --git a/src/routines/level2/xgemv.cpp b/src/routines/level2/xgemv.cpp
index 9e9c2db4..aae66798 100644
--- a/src/routines/level2/xgemv.cpp
+++ b/src/routines/level2/xgemv.cpp
@@ -22,7 +22,7 @@ namespace clblast {
// Constructor: forwards to base class constructor
template <typename T>
Xgemv<T>::Xgemv(Queue &queue, EventPointer event, const std::string &name):
- Routine(queue, event, name, {"Pad", "Xgemv", "XgemvFast", "XgemvFastRot"}, PrecisionValue<T>(), {}, {
+ Routine(queue, event, name, {"Xgemv", "XgemvFast", "XgemvFastRot"}, PrecisionValue<T>(), {}, {
#include "../../kernels/level2/xgemv.opencl"
#include "../../kernels/level2/xgemv_fast.opencl"
}) {