summary | refs | log | tree | commit | diff
path: root/src/kernels/common.opencl
diff options
context:
space:
mode:
Diffstat (limited to 'src/kernels/common.opencl')
-rw-r--r-- src/kernels/common.opencl 33
1 files changed, 20 insertions, 13 deletions
diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl
index 9481881e..01c411bc 100644
--- a/src/kernels/common.opencl
+++ b/src/kernels/common.opencl
@@ -24,14 +24,16 @@ R"(
// =================================================================================================
-// Enable support for double-precision
-#if PRECISION == 16
- #pragma OPENCL EXTENSION cl_khr_fp16: enable
-#endif
+#ifndef CUDA
+ // Enable support for half-precision
+ #if PRECISION == 16
+ #pragma OPENCL EXTENSION cl_khr_fp16: enable
+ #endif
-// Enable support for double-precision
-#if PRECISION == 64 || PRECISION == 6464
- #pragma OPENCL EXTENSION cl_khr_fp64: enable
+ // Enable support for double-precision
+ #if PRECISION == 64 || PRECISION == 6464
+ #pragma OPENCL EXTENSION cl_khr_fp64: enable
+ #endif
#endif
// Half-precision
@@ -117,10 +119,15 @@ R"(
#define GetRealArg(x) x
#endif
+// Pointers to local memory objects (using a define because CUDA doesn't need them)
+#ifndef LOCAL_PTR
+ #define LOCAL_PTR __local
+#endif
+
// =================================================================================================
// Don't use the non-IEEE754 compliant OpenCL built-in mad() instruction per default. For specific
-// devices, this is enabled (see src/routine.cc).
+// devices, this is enabled (see src/routine.cpp).
#ifndef USE_CL_MAD
#define USE_CL_MAD 0
#endif
@@ -254,18 +261,18 @@ R"(
// http://docs.nvidia.com/cuda/samples/6_Advanced/transpose/doc/MatrixTranspose.pdf
// More details: https://github.com/CNugteren/CLBlast/issues/53
#if USE_STAGGERED_INDICES == 1
- INLINE_FUNC size_t GetGroupIDFlat() {
+ INLINE_FUNC int GetGroupIDFlat() {
return get_group_id(0) + get_num_groups(0) * get_group_id(1);
}
- INLINE_FUNC size_t GetGroupID1() {
+ INLINE_FUNC int GetGroupID1() {
return (GetGroupIDFlat()) % get_num_groups(1);
}
- INLINE_FUNC size_t GetGroupID0() {
+ INLINE_FUNC int GetGroupID0() {
return ((GetGroupIDFlat() / get_num_groups(1)) + GetGroupID1()) % get_num_groups(0);
}
#else
- INLINE_FUNC size_t GetGroupID1() { return get_group_id(1); }
- INLINE_FUNC size_t GetGroupID0() { return get_group_id(0); }
+ INLINE_FUNC int GetGroupID1() { return get_group_id(1); }
+ INLINE_FUNC int GetGroupID0() { return get_group_id(0); }
#endif
// =================================================================================================