summary | refs | log | tree | commit | diff
path: root/src/kernels/level3/xgemm_part1.opencl
diff options
context:
space:
mode:
Diffstat (limited to 'src/kernels/level3/xgemm_part1.opencl')
-rw-r--r-- src/kernels/level3/xgemm_part1.opencl 6
1 files changed, 4 insertions, 2 deletions
diff --git a/src/kernels/level3/xgemm_part1.opencl b/src/kernels/level3/xgemm_part1.opencl
index cbc43d51..d15dafc8 100644
--- a/src/kernels/level3/xgemm_part1.opencl
+++ b/src/kernels/level3/xgemm_part1.opencl
@@ -298,11 +298,12 @@ INLINE_FUNC realN GlobalToPrivateB(const __global realN* restrict bgm, const int
// is specific for caching the A input matrix for kernel 1.
INLINE_FUNC realN GlobalToPrivateA2D(const __global real* restrict a_ptr, const int tid_y, const int _ni,
const int kSizeK, const int idk, const int _ki) {
- const int a_index = (tid_y * NWI + _ni) * kSizeK + idk + _ki * VWN;
#if PRECISION == 3232 || PRECISION == 6464
+ const int a_index = (tid_y * NWI + _ni) * (kSizeK / VWN) + idk / VWN + _ki;
const __global realN* restrict agm = (const __global realN* restrict) a_ptr;
return agm[a_index];
#else
+ const int a_index = (tid_y * NWI + _ni) * kSizeK + idk + _ki * VWN;
#if VWN == 1
return a_ptr[a_index];
#elif VWN == 2
@@ -320,11 +321,12 @@ INLINE_FUNC realN GlobalToPrivateA2D(const __global real* restrict a_ptr, const
// Same as above, but now for the B input matrix
INLINE_FUNC realM GlobalToPrivateB2D(const __global real* restrict b_ptr, const int tid_x, const int _mi,
const int kSizeN, const int idk, const int _ki) {
- const int b_index = (idk + _ki) * kSizeN + tid_x * MWI + _mi * VWM;
#if PRECISION == 3232 || PRECISION == 6464
+ const int b_index = (idk + _ki) * (kSizeN / VWM) + tid_x * (MWI / VWM) + _mi;
const __global realM* restrict bgm = (const __global realM* restrict) b_ptr;
return bgm[b_index];
#else
+ const int b_index = (idk + _ki) * kSizeN + tid_x * MWI + _mi * VWM;
#if VWM == 1
return b_ptr[b_index];
#elif VWM == 2