summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/cache.cpp6
-rw-r--r--src/cache.hpp14
-rw-r--r--src/routine.cpp24
-rw-r--r--src/routine.hpp13
4 files changed, 54 insertions, 3 deletions
diff --git a/src/cache.cpp b/src/cache.cpp
index 2b91abf1..c5cc6a4d 100644
--- a/src/cache.cpp
+++ b/src/cache.cpp
@@ -15,6 +15,7 @@
#include <vector>
#include <mutex>
+#include "database/database.hpp"
#include "cache.hpp"
namespace clblast {
@@ -89,4 +90,9 @@ template class Cache<ProgramKey, Program>;
template Program ProgramCache::Get(const ProgramKeyRef &, bool *) const;
// =================================================================================================
+
+template class Cache<DatabaseKey, Database>;
+template Database DatabaseCache::Get(const DatabaseKeyRef &, bool *) const;
+
+// =================================================================================================
} // namespace clblast
diff --git a/src/cache.hpp b/src/cache.hpp
index 25d6f076..c3675f07 100644
--- a/src/cache.hpp
+++ b/src/cache.hpp
@@ -86,6 +86,20 @@ extern template class Cache<ProgramKey, Program>;
extern template Program ProgramCache::Get(const ProgramKeyRef &, bool *) const;
// =================================================================================================
+
+class Database;
+
+// The key struct for the cache of database maps.
+// Order of fields: precision, device_name, routines (smaller fields first)
+typedef std::tuple<Precision, std::string, std::vector<std::string>> DatabaseKey;
+typedef std::tuple<const Precision &, const std::string &, const std::vector<std::string> &> DatabaseKeyRef;
+
+typedef Cache<DatabaseKey, Database> DatabaseCache;
+
+extern template class Cache<DatabaseKey, Database>;
+extern template Database DatabaseCache::Get(const DatabaseKeyRef &, bool *) const;
+
+// =================================================================================================
} // namespace clblast
// CLBLAST_CACHE_H_
diff --git a/src/routine.cpp b/src/routine.cpp
index 75e4ea89..b36a1ecd 100644
--- a/src/routine.cpp
+++ b/src/routine.cpp
@@ -32,8 +32,28 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
event_(event),
context_(queue_.GetContext()),
device_(queue_.GetDevice()),
- device_name_(device_.Name()),
- db_(queue_, routines, precision_, userDatabase) {
+ device_name_(device_.Name()) {
+
+ InitDatabase(routines, userDatabase);
+ InitProgram(source);
+}
+
+void Routine::InitDatabase(const std::vector<std::string> &routines,
+ const std::vector<const Database::DatabaseEntry*> &userDatabase) {
+
+ // Queries the cache to see whether or not the kernel parameter database is already there
+ bool has_db;
+ db_ = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision_, device_name_, routines },
+ &has_db);
+ if (has_db) { return; }
+
+ // Builds the parameter database for this device and routine set and stores it in the cache
+ db_ = Database(queue_, routines, precision_, userDatabase);
+ DatabaseCache::Instance().Store(DatabaseKey{ precision_, device_name_, routines },
+ Database{ db_ });
+}
+
+void Routine::InitProgram(std::initializer_list<const char *> source) {
// Queries the cache to see whether or not the program (context-specific) is already there
bool has_program;
diff --git a/src/routine.hpp b/src/routine.hpp
index 8e9fd54d..f366e4d9 100644
--- a/src/routine.hpp
+++ b/src/routine.hpp
@@ -35,11 +35,22 @@ class Routine {
// Base class constructor. The user database is an optional extra database to override the
// built-in database.
// All heavy preparation work is done inside this constructor.
+ // NOTE: the caller must provide the same userDatabase for each combination of device, precision
+ // and routine list, otherwise the caching logic will break.
explicit Routine(Queue &queue, EventPointer event, const std::string &name,
const std::vector<std::string> &routines, const Precision precision,
const std::vector<const Database::DatabaseEntry*> &userDatabase,
std::initializer_list<const char *> source);
+ private:
+
+ // Initializes program_, fetching cached program or building one
+ void InitProgram(std::initializer_list<const char *> source);
+
+ // Initializes db_, fetching cached database or building one
+ void InitDatabase(const std::vector<std::string> &routines,
+ const std::vector<const Database::DatabaseEntry*> &userDatabase);
+
protected:
// Non-static variable for the precision
@@ -61,7 +72,7 @@ class Routine {
Program program_;
// Connection to the database for all the device-specific parameters
- const Database db_;
+ Database db_;
};
// =================================================================================================