summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/cache.hpp6
-rw-r--r--src/clblast.cpp9
-rw-r--r--src/clpp11.hpp1
-rw-r--r--src/routine.cpp13
-rw-r--r--src/routine.hpp4
5 files changed, 17 insertions, 16 deletions
diff --git a/src/cache.hpp b/src/cache.hpp
index ed693ea3..f6a948b6 100644
--- a/src/cache.hpp
+++ b/src/cache.hpp
@@ -93,9 +93,9 @@ extern template Program ProgramCache::Get(const ProgramKeyRef &, bool *) const;
class Database;
// The key struct for the cache of database maps.
-// Order of fields: precision, device_name, kernel_name (smaller fields first)
-typedef std::tuple<Precision, std::string, std::string> DatabaseKey;
-typedef std::tuple<const Precision &, const std::string &, const std::string &> DatabaseKeyRef;
+// Order of fields: platform_id, device_id, precision, kernel_name (smaller fields first)
+typedef std::tuple<cl_platform_id, cl_device_id, Precision, std::string> DatabaseKey;
+typedef std::tuple<const cl_platform_id &, const cl_device_id &, const Precision &, const std::string &> DatabaseKeyRef;
typedef Cache<DatabaseKey, Database> DatabaseCache;
diff --git a/src/clblast.cpp b/src/clblast.cpp
index d44649bb..3983e5fc 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2492,11 +2492,12 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern
// Retrieves the device name
const auto device_cpp = Device(device);
- const auto device_name = device_cpp.Name();
+ const auto platform_id = device_cpp.Platform();
+ const auto device_name = GetDeviceName(device_cpp);
// Retrieves the current database values to verify whether the new ones are complete
auto in_cache = false;
- const auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision, device_name, kernel_name }, &in_cache);
+ const auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device, precision, kernel_name}, &in_cache);
if (!in_cache) { return StatusCode::kInvalidOverrideKernel; }
for (const auto &current_param : current_database.GetParameterNames()) {
if (parameters.find(current_param) == parameters.end()) {
@@ -2530,8 +2531,8 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern
const auto database = Database(device_cpp, kernel_name, precision, database_entries);
// Removes the old database entry and stores the new one in the cache
- DatabaseCache::Instance().Remove(DatabaseKey{ precision, device_name, kernel_name });
- DatabaseCache::Instance().Store(DatabaseKey{ precision, device_name, kernel_name }, Database(database));
+ DatabaseCache::Instance().Remove(DatabaseKey{platform_id, device, precision, kernel_name});
+ DatabaseCache::Instance().Store(DatabaseKey{platform_id, device, precision, kernel_name}, Database(database));
} catch (...) { return DispatchException(); }
return StatusCode::kSuccess;
diff --git a/src/clpp11.hpp b/src/clpp11.hpp
index 7c1457b0..7d348e18 100644
--- a/src/clpp11.hpp
+++ b/src/clpp11.hpp
@@ -230,6 +230,7 @@ class Device {
}
// Methods to retrieve device information
+ cl_platform_id Platform() const { return GetInfo<cl_platform_id>(CL_DEVICE_PLATFORM); }
std::string Version() const { return GetInfoString(CL_DEVICE_VERSION); }
size_t VersionNumber() const
{
diff --git a/src/routine.cpp b/src/routine.cpp
index 758ffa0c..c305feb8 100644
--- a/src/routine.cpp
+++ b/src/routine.cpp
@@ -60,7 +60,7 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
event_(event),
context_(queue_.GetContext()),
device_(queue_.GetDevice()),
- device_name_(device_.Name()),
+ platform_(device_.Platform()),
db_(kernel_names) {
InitDatabase(userDatabase);
@@ -72,13 +72,13 @@ void Routine::InitDatabase(const std::vector<database::DatabaseEntry> &userDatab
// Queries the cache to see whether or not the kernel parameter database is already there
bool has_db;
- db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision_, device_name_, kernel_name },
+ db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ platform_, device_(), precision_, kernel_name },
&has_db);
if (has_db) { continue; }
// Builds the parameter database for this device and routine set and stores it in the cache
db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase);
- DatabaseCache::Instance().Store(DatabaseKey{ precision_, device_name_, kernel_name },
+ DatabaseCache::Instance().Store(DatabaseKey{ platform_, device_(), precision_, kernel_name },
Database{ db_(kernel_name) });
}
}
@@ -100,8 +100,9 @@ void Routine::InitProgram(std::initializer_list<const char *> source) {
// Queries the cache to see whether or not the binary (device-specific) is already there. If it
// is, a program is created and stored in the cache
+ const auto device_name = GetDeviceName(device_);
bool has_binary;
- auto binary = BinaryCache::Instance().Get(BinaryKeyRef{ precision_, routine_name_, device_name_ },
+ auto binary = BinaryCache::Instance().Get(BinaryKeyRef{ precision_, routine_name_, device_name },
&has_binary);
if (has_binary) {
program_ = Program(device_, context_, binary);
@@ -171,7 +172,7 @@ void Routine::InitProgram(std::initializer_list<const char *> source) {
// Prints details of the routine to compile in case of debugging in verbose mode
#ifdef VERBOSE
printf("[DEBUG] Compiling routine '%s-%s' for device '%s'\n",
- routine_name_.c_str(), ToString(precision_).c_str(), device_name_.c_str());
+ routine_name_.c_str(), ToString(precision_).c_str(), device_name.c_str());
const auto start_time = std::chrono::steady_clock::now();
#endif
@@ -188,7 +189,7 @@ void Routine::InitProgram(std::initializer_list<const char *> source) {
}
// Store the compiled binary and program in the cache
- BinaryCache::Instance().Store(BinaryKey{ precision_, routine_name_, device_name_ },
+ BinaryCache::Instance().Store(BinaryKey{ precision_, routine_name_, device_name },
program_.GetIR());
ProgramCache::Instance().Store(ProgramKey{ context_(), device_(), precision_, routine_name_ },
diff --git a/src/routine.hpp b/src/routine.hpp
index 5e2b4065..e77e35ad 100644
--- a/src/routine.hpp
+++ b/src/routine.hpp
@@ -75,9 +75,7 @@ class Routine {
EventPointer event_;
const Context context_;
const Device device_;
-
- // OpenCL device properties
- const std::string device_name_;
+ const cl_platform_id platform_;
// Compiled program (either retrieved from cache or compiled in slow path)
Program program_;