summaryrefslogtreecommitdiff
path: root/src/utilities
diff options
context:
space:
mode:
Diffstat (limited to 'src/utilities')
-rw-r--r-- src/utilities/compile.cpp 14
1 files changed, 14 insertions, 0 deletions
diff --git a/src/utilities/compile.cpp b/src/utilities/compile.cpp
index 05c29944..835f54b4 100644
--- a/src/utilities/compile.cpp
+++ b/src/utilities/compile.cpp
@@ -61,8 +61,22 @@ std::shared_ptr<Program> CompileFromSource(
// For Intel GPUs with subgroup support, use subgroup shuffling.
if (device.IsGPU() && device.HasExtension(kKhronosIntelSubgroups)) {
header_string += "#define USE_SUBGROUP_SHUFFLING 1\n";
+ header_string += "#define SUBGROUP_SHUFFLING_INTEL 1\n";
}
+ // For NVIDIA GPUs in single precision, inline PTX can provide subgroup shuffling.
+ if (device.IsGPU() && device.IsNVIDIA() && precision == Precision::kSingle) {
+ header_string += "#define USE_SUBGROUP_SHUFFLING 1\n";
+
+ // NVIDIA: distinguish pre- and post-Volta devices, since the newer GPUs use different shuffle commands.
+ if (device.IsPostNVIDIAVolta()) {
+ header_string += "#define SUBGROUP_SHUFFLING_NVIDIA_POST_VOLTA 1\n";
+ }
+ else {
+ header_string += "#define SUBGROUP_SHUFFLING_NVIDIA_PRE_VOLTA 1\n";
+ }
+ }
+
// Optionally adds a translation header from OpenCL kernels to CUDA kernels
#ifdef CUDA_API
header_string +=