From a45992010591bfbf46fdc99496e68982cad163b9 Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Sat, 1 Oct 2016 16:58:53 +0200 Subject: Added padding to the local memory of the GEMM direct kernel --- src/kernels/level3/xgemm_direct.opencl | 173 ++++++++++++++++++--------------- 1 file changed, 94 insertions(+), 79 deletions(-) (limited to 'src/kernels/level3') diff --git a/src/kernels/level3/xgemm_direct.opencl b/src/kernels/level3/xgemm_direct.opencl index 705ced9c..75618e8c 100644 --- a/src/kernels/level3/xgemm_direct.opencl +++ b/src/kernels/level3/xgemm_direct.opencl @@ -43,6 +43,12 @@ R"( #ifndef VWND #define VWND 1 // Vector width of matrix B #endif +#ifndef PADA + #define PADA 1 // Local memory padding for matrix A +#endif +#ifndef PADB + #define PADB 1 // Local memory padding for matrix B +#endif // Helper parameters based on the above tuning parameters #define MWID (WGD/MDIMCD) // Work per work-item (M-dimension) @@ -87,10 +93,16 @@ R"( // Caches global off-chip memory into local (shared) memory on-chip. This function is specific for // caching the A input matrix. inline void GlobalToLocalDirectA(const __global realMD* restrict agm, __local real* alm, - const int a_ld, const int a_offset, const int tid, const int kwg, + const int a_ld, const int a_offset, const int kwg, const int a_transpose, const int a_conjugate) { - const int la0 = tid % MDIMAD; - const int la1 = tid / MDIMAD; + #if MDIMCD == MDIMAD + const int la0 = get_local_id(0); + const int la1 = get_local_id(1); + #else + const int tid = get_local_id(0) + MDIMCD*get_local_id(1); + const int la0 = tid % MDIMAD; + const int la1 = tid / MDIMAD; + #endif #pragma unroll for (int mia=0; mia local (matrix A and B) - GlobalToLocalDirectA(agm, alm, a_ld, a_offset, tid, kwg, a_transpose, a_conjugate); - GlobalToLocalDirectB(bgm, blm, b_ld, b_offset, tid, kwg, b_transpose, b_conjugate); + GlobalToLocalDirectA(agm, alm, a_ld, a_offset, kwg, a_transpose, a_conjugate); + GlobalToLocalDirectB(bgm, blm, b_ld, b_offset, kwg, b_transpose, b_conjugate); barrier(CLK_LOCAL_MEM_FENCE); // Loops over all workitem tiles, unrolled by a factor KWID -- cgit v1.2.3