summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-02-12 12:02:39 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2017-02-12 12:02:39 +0100
commit345a5feb9a18641ceffd7ce5e0cb9387686cf32c (patch)
treecdae4f7fd42f4b23349867243434a1d648ed3e41 /src
parentfaa842b927ede6df1763607e3732151162875d73 (diff)
Split the database into several smaller cached per-kernel databases (in preparation of per-kernel database overrides)
Diffstat (limited to 'src')
-rw-r--r--src/cache.hpp6
-rw-r--r--src/database/database.cpp23
-rw-r--r--src/database/database.hpp30
-rw-r--r--src/routine.cpp36
-rw-r--r--src/routine.hpp8
-rw-r--r--src/routines/common.hpp2
6 files changed, 67 insertions, 38 deletions
diff --git a/src/cache.hpp b/src/cache.hpp
index f7ca3dc8..694de839 100644
--- a/src/cache.hpp
+++ b/src/cache.hpp
@@ -93,9 +93,9 @@ extern template Program ProgramCache::Get(const ProgramKeyRef &, bool *) const;
class Database;
// The key struct for the cache of database maps.
-// Order of fields: precision, device_name, routines (smaller fields first)
-typedef std::tuple<Precision, std::string, std::vector<std::string>> DatabaseKey;
-typedef std::tuple<const Precision &, const std::string &, const std::vector<std::string> &> DatabaseKeyRef;
+// Order of fields: precision, device_name, kernel_name (smaller fields first)
+typedef std::tuple<Precision, std::string, std::string> DatabaseKey;
+typedef std::tuple<const Precision &, const std::string &, const std::string &> DatabaseKeyRef;
typedef Cache<DatabaseKey, Database> DatabaseCache;
diff --git a/src/database/database.cpp b/src/database/database.cpp
index c1cb9d56..8019d558 100644
--- a/src/database/database.cpp
+++ b/src/database/database.cpp
@@ -63,7 +63,7 @@ const std::unordered_map<std::string, std::string> Database::kVendorNames{
// Constructor, computing device properties and populating the parameter-vector from the database.
// This takes an optional overlay database in case of custom tuning or custom kernels.
-Database::Database(const Device &device, const std::vector<std::string> &kernels,
+Database::Database(const Device &device, const std::string &kernel_name,
const Precision precision, const std::vector<const DatabaseEntry*> &overlay):
parameters_(std::make_shared<Parameters>()) {
@@ -79,20 +79,17 @@ Database::Database(const Device &device, const std::vector<std::string> &kernels
}
}
- // Iterates over all kernels to include, and retrieves the parameters for each of them
- for (auto &kernel: kernels) {
- auto search_result = ParametersPtr{};
-
- for (auto &db: { database, overlay}) {
- search_result = Search(kernel, device_type, device_vendor, device_name, precision, db);
- if (search_result) {
- parameters_->insert(search_result->begin(), search_result->end());
- break;
- }
+ // Searches potentially multiple databases
+ auto search_result = ParametersPtr{};
+ for (auto &db: { overlay, database}) {
+ search_result = Search(kernel_name, device_type, device_vendor, device_name, precision, db);
+ if (search_result) {
+ parameters_->insert(search_result->begin(), search_result->end());
+ break;
}
-
- if (!search_result) { throw RuntimeErrorCode(StatusCode::kDatabaseError); }
}
+
+ if (!search_result) { throw RuntimeErrorCode(StatusCode::kDatabaseError); }
}
// =================================================================================================
diff --git a/src/database/database.hpp b/src/database/database.hpp
index 87c12293..b6760ec3 100644
--- a/src/database/database.hpp
+++ b/src/database/database.hpp
@@ -75,11 +75,12 @@ class Database {
Database() = default;
// The constructor with a user-provided database overlay (potentially an empty vector)
- explicit Database(const Device &device, const std::vector<std::string> &routines,
+ explicit Database(const Device &device, const std::string &kernel_name,
const Precision precision, const std::vector<const DatabaseEntry*> &overlay);
// Accessor of values by key
- size_t operator[](const std::string key) const { return parameters_->find(key)->second; }
+ size_t operator[](const std::string &key) const { return parameters_->find(key)->second; }
+ bool exists(const std::string &key) const { return (parameters_->count(key) == 1); }
// Obtain a list of OpenCL pre-processor defines based on the parameters
std::string GetDefines() const;
@@ -96,6 +97,31 @@ class Database {
};
// =================================================================================================
+
+// Multiple databases together in a map
+class Databases {
+ public:
+
+ explicit Databases(const std::vector<std::string> &kernel_names): kernel_names_(kernel_names) { }
+
+ // Database accessor
+ Database& operator()(const std::string &kernel_name) { return databases_[kernel_name]; }
+
+ // Retrieves a parameter from the database
+ size_t operator[](const std::string &key) const {
+ for (const auto &kernel_name : kernel_names_) {
+ const auto &kernel_db = databases_.find(kernel_name)->second;
+ if (kernel_db.exists(key)) { return kernel_db[key]; }
+ }
+ throw RuntimeErrorCode(StatusCode::kDatabaseError);
+ }
+
+ private:
+ const std::vector<std::string> kernel_names_;
+ std::unordered_map<std::string, Database> databases_;
+};
+
+// =================================================================================================
} // namespace clblast
// CLBLAST_DATABASE_H_
diff --git a/src/routine.cpp b/src/routine.cpp
index 4fe04a60..854c7046 100644
--- a/src/routine.cpp
+++ b/src/routine.cpp
@@ -23,34 +23,37 @@ namespace clblast {
// The constructor does all heavy work, errors are returned as exceptions
Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
- const std::vector<std::string> &routines, const Precision precision,
+ const std::vector<std::string> &kernel_names, const Precision precision,
const std::vector<const Database::DatabaseEntry*> &userDatabase,
std::initializer_list<const char *> source):
precision_(precision),
routine_name_(name),
+ kernel_names_(kernel_names),
queue_(queue),
event_(event),
context_(queue_.GetContext()),
device_(queue_.GetDevice()),
- device_name_(device_.Name()) {
+ device_name_(device_.Name()),
+ db_(kernel_names) {
- InitDatabase(routines, userDatabase);
+ InitDatabase(userDatabase);
InitProgram(source);
}
-void Routine::InitDatabase(const std::vector<std::string> &routines,
- const std::vector<const Database::DatabaseEntry*> &userDatabase) {
+void Routine::InitDatabase(const std::vector<const Database::DatabaseEntry*> &userDatabase) {
+ for (const auto &kernel_name : kernel_names_) {
- // Queries the cache to see whether or not the kernel parameter database is already there
- bool has_db;
- db_ = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision_, device_name_, routines },
- &has_db);
- if (has_db) { return; }
+ // Queries the cache to see whether or not the kernel parameter database is already there
+ bool has_db;
+ db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision_, device_name_, kernel_name },
+ &has_db);
+ if (has_db) { continue; }
- // Builds the parameter database for this device and routine set and stores it in the cache
- db_ = Database(device_, routines, precision_, userDatabase);
- DatabaseCache::Instance().Store(DatabaseKey{ precision_, device_name_, routines },
- Database{ db_ });
+ // Builds the parameter database for this device and routine set and stores it in the cache
+ db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase);
+ DatabaseCache::Instance().Store(DatabaseKey{ precision_, device_name_, kernel_name },
+ Database{ db_(kernel_name) });
+ }
}
void Routine::InitProgram(std::initializer_list<const char *> source) {
@@ -96,7 +99,10 @@ void Routine::InitProgram(std::initializer_list<const char *> source) {
}
// Collects the parameters for this device in the form of defines, and adds the precision
- auto source_string = db_.GetDefines();
+ auto source_string = std::string{""};
+ for (const auto &kernel_name : kernel_names_) {
+ source_string += db_(kernel_name).GetDefines();
+ }
source_string += "#define PRECISION "+ToString(static_cast<int>(precision_))+"\n";
// Adds the name of the routine as a define
diff --git a/src/routine.hpp b/src/routine.hpp
index f366e4d9..ba8b9f60 100644
--- a/src/routine.hpp
+++ b/src/routine.hpp
@@ -48,16 +48,16 @@ class Routine {
void InitProgram(std::initializer_list<const char *> source);
// Initializes db_, fetching cached database or building one
- void InitDatabase(const std::vector<std::string> &routines,
- const std::vector<const Database::DatabaseEntry*> &userDatabase);
+ void InitDatabase(const std::vector<const Database::DatabaseEntry*> &userDatabase);
protected:
// Non-static variable for the precision
const Precision precision_;
- // The routine's name
+ // The routine's name and the corresponding kernels
const std::string routine_name_;
+ const std::vector<std::string> kernel_names_;
// The OpenCL objects, accessible only from derived classes
Queue queue_;
@@ -72,7 +72,7 @@ class Routine {
Program program_;
// Connection to the database for all the device-specific parameters
- Database db_;
+ Databases db_;
};
// =================================================================================================
diff --git a/src/routines/common.hpp b/src/routines/common.hpp
index 7c211c0d..d268d58b 100644
--- a/src/routines/common.hpp
+++ b/src/routines/common.hpp
@@ -37,7 +37,7 @@ void RunKernel(Kernel &kernel, Queue &queue, const Device &device,
// to write to symmetric and triangular matrices through optional arguments.
template <typename T>
void PadCopyTransposeMatrix(Queue &queue, const Device &device,
- const Database &db,
+ const Databases &db,
EventPointer event, const std::vector<Event> &waitForEvents,
const size_t src_one, const size_t src_two,
const size_t src_ld, const size_t src_offset,