diff options
Diffstat (limited to 'src/routine.cpp')
-rw-r--r-- | src/routine.cpp | 71 |
1 files changed, 48 insertions, 23 deletions
diff --git a/src/routine.cpp b/src/routine.cpp index acafb0d2..4fe04a60 100644 --- a/src/routine.cpp +++ b/src/routine.cpp @@ -32,11 +32,34 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name, event_(event), context_(queue_.GetContext()), device_(queue_.GetDevice()), - device_name_(device_.Name()), - db_(queue_, routines, precision_, userDatabase) { + device_name_(device_.Name()) { + + InitDatabase(routines, userDatabase); + InitProgram(source); +} + +void Routine::InitDatabase(const std::vector<std::string> &routines, + const std::vector<const Database::DatabaseEntry*> &userDatabase) { + + // Queries the cache to see whether or not the kernel parameter database is already there + bool has_db; + db_ = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision_, device_name_, routines }, + &has_db); + if (has_db) { return; } + + // Builds the parameter database for this device and routine set and stores it in the cache + db_ = Database(device_, routines, precision_, userDatabase); + DatabaseCache::Instance().Store(DatabaseKey{ precision_, device_name_, routines }, + Database{ db_ }); +} + +void Routine::InitProgram(std::initializer_list<const char *> source) { // Queries the cache to see whether or not the program (context-specific) is already there - if (ProgramIsInCache(context_, precision_, routine_name_)) { return; } + bool has_program; + program_ = ProgramCache::Instance().Get(ProgramKeyRef{ context_(), precision_, routine_name_ }, + &has_program); + if (has_program) { return; } // Sets the build options from an environmental variable (if set) auto options = std::vector<std::string>(); @@ -47,29 +70,29 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name, // Queries the cache to see whether or not the binary (device-specific) is already there. If it // is, a program is created and stored in the cache - if (BinaryIsInCache(device_name_, precision_, routine_name_)) { - auto& binary = GetBinaryFromCache(device_name_, precision_, routine_name_); - auto program = Program(device_, context_, binary); - program.Build(device_, options); - StoreProgramToCache(program, context_, precision_, routine_name_); + bool has_binary; + auto binary = BinaryCache::Instance().Get(BinaryKeyRef{ precision_, routine_name_, device_name_ }, + &has_binary); + if (has_binary) { + program_ = Program(device_, context_, binary); + program_.Build(device_, options); + ProgramCache::Instance().Store(ProgramKey{ context_(), precision_, routine_name_ }, + Program{ program_ }); + return; } // Otherwise, the kernel will be compiled and program will be built. Both the binary and the // program will be added to the cache. // Inspects whether or not cl_khr_fp64 is supported in case of double precision - const auto extensions = device_.Capabilities(); - if (precision_ == Precision::kDouble || precision_ == Precision::kComplexDouble) { - if (extensions.find(kKhronosDoublePrecision) == std::string::npos) { - throw RuntimeErrorCode(StatusCode::kNoDoublePrecision); - } + if ((precision_ == Precision::kDouble && !PrecisionSupported<double>(device_)) || + (precision_ == Precision::kComplexDouble && !PrecisionSupported<double2>(device_))) { + throw RuntimeErrorCode(StatusCode::kNoDoublePrecision); } // As above, but for cl_khr_fp16 (half precision) - if (precision_ == Precision::kHalf) { - if (extensions.find(kKhronosHalfPrecision) == std::string::npos) { - throw RuntimeErrorCode(StatusCode::kNoHalfPrecision); - } + if (precision_ == Precision::kHalf && !PrecisionSupported<half>(device_)) { + throw RuntimeErrorCode(StatusCode::kNoHalfPrecision); } // Collects the parameters for this device in the form of defines, and adds the precision @@ -114,21 +137,23 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name, #endif // Compiles the kernel - auto program = Program(context_, source_string); + program_ = Program(context_, source_string); try { - program.Build(device_, options); + program_.Build(device_, options); } catch (const CLError &e) { if (e.status() == CL_BUILD_PROGRAM_FAILURE) { fprintf(stdout, "OpenCL compiler error/warning: %s\n", - program.GetBuildInfo(device_).c_str()); + program_.GetBuildInfo(device_).c_str()); } throw; } // Store the compiled binary and program in the cache - const auto binary = program.GetIR(); - StoreBinaryToCache(binary, device_name_, precision_, routine_name_); - StoreProgramToCache(program, context_, precision_, routine_name_); + BinaryCache::Instance().Store(BinaryKey{ precision_, routine_name_, device_name_ }, + program_.GetIR()); + + ProgramCache::Instance().Store(ProgramKey{ context_(), precision_, routine_name_ }, + Program{ program_ }); // Prints the elapsed compilation time in case of debugging in verbose mode #ifdef VERBOSE |