diff options
Diffstat (limited to 'src/utilities')
-rw-r--r-- | src/utilities/compile.cpp | 14 | ||||
-rw-r--r-- | src/utilities/device_mapping.hpp | 5 | ||||
-rw-r--r-- | src/utilities/utilities.cpp | 8 |
3 files changed, 27 insertions, 0 deletions
diff --git a/src/utilities/compile.cpp b/src/utilities/compile.cpp index 05c29944..835f54b4 100644 --- a/src/utilities/compile.cpp +++ b/src/utilities/compile.cpp @@ -61,8 +61,22 @@ std::shared_ptr<Program> CompileFromSource( // For Intel GPUs with subgroup support, use subgroup shuffling. if (device.IsGPU() && device.HasExtension(kKhronosIntelSubgroups)) { header_string += "#define USE_SUBGROUP_SHUFFLING 1\n"; + header_string += "#define SUBGROUP_SHUFFLING_INTEL 1\n"; } + // For NVIDIA GPUs, inline PTX can provide subgroup support + if (device.IsGPU() && device.IsNVIDIA() && precision == Precision::kSingle) { + header_string += "#define USE_SUBGROUP_SHUFFLING 1\n"; + + // Nvidia needs to check pre or post volta due to new shuffle commands + if (device.IsPostNVIDIAVolta()) { + header_string += "#define SUBGROUP_SHUFFLING_NVIDIA_POST_VOLTA 1\n"; + } + else { + header_string += "#define SUBGROUP_SHUFFLING_NVIDIA_PRE_VOLTA 1\n"; + } + } + // Optionally adds a translation header from OpenCL kernels to CUDA kernels #ifdef CUDA_API header_string += diff --git a/src/utilities/device_mapping.hpp b/src/utilities/device_mapping.hpp index 7fdc04a0..c814622f 100644 --- a/src/utilities/device_mapping.hpp +++ b/src/utilities/device_mapping.hpp @@ -43,6 +43,11 @@ const std::unordered_map<std::string, std::string> kDeviceNames { // Empty }; +// Things to remove from device names (low-level) +const std::vector<std::string> kDeviceRemovals { + "pthread-" +}; + // ================================================================================================= } // namespace device_mapping } // namespace clblast diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp index 2008b6a3..a8fdaa19 100644 --- a/src/utilities/utilities.cpp +++ b/src/utilities/utilities.cpp @@ -477,6 +477,14 @@ std::string GetDeviceName(const Device& device) { for (auto &find_and_replace : device_mapping::kDeviceNames) { // replacing to common names if (device_name == find_and_replace.first) { device_name = find_and_replace.second; } } + + for (auto &removal : device_mapping::kDeviceRemovals) { // removing certain things + if (device_name.find(removal) != std::string::npos) { + auto start_position_to_erase = device_name.find(removal); + device_name.erase(start_position_to_erase, removal.length()); + } + } + return device_name; } |