summary | refs | log | tree | commit | diff
path: root/src/utilities/compile.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/utilities/compile.cpp')
-rw-r--r-- src/utilities/compile.cpp 22
1 files changed, 19 insertions, 3 deletions
diff --git a/src/utilities/compile.cpp b/src/utilities/compile.cpp
index 05c29944..cd0b3d2b 100644
--- a/src/utilities/compile.cpp
+++ b/src/utilities/compile.cpp
@@ -58,11 +58,27 @@ std::shared_ptr<Program> CompileFromSource(
header_string += "#define GLOBAL_MEM_FENCE 1\n";
}
- // For Intel GPUs with subgroup support, use subgroup shuffling.
- if (device.IsGPU() && device.HasExtension(kKhronosIntelSubgroups)) {
+ // For GPUs with subgroup support, use subgroup shuffling.
+ // Currently these are Intel (via an extension) and Nvidia (using inline PTX, restricted to 32-bit precision).
+ if (device.IsGPU() && (device.HasExtension(kKhronosIntelSubgroups) ||
+ (device.IsNVIDIA() && static_cast<int>(precision) == 32))) {
header_string += "#define USE_SUBGROUP_SHUFFLING 1\n";
- }
+ // Define the flavor of subgroup
+ if (device.IsNVIDIA()) {
+ header_string += "#define NVIDIA_WARPS_AS_SUBGROUPS 1\n";
+
+ // Nvidia additionally needs to distinguish pre-Volta from post-Volta devices,
+ // because of the new shuffle commands
+ if (device.IsPostNVIDIAVolta()) {
+ header_string += "#define NVIDIA_POST_VOLTA 1\n";
+ }
+ }
+ else if (device.HasExtension(kKhronosIntelSubgroups)) {
+ header_string += "#define INTEL_SUBGROUP_EXTENSION 1\n";
+ }
+ }
+
// Optionally adds a translation header from OpenCL kernels to CUDA kernels
#ifdef CUDA_API
header_string +=