diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-05-17 09:23:28 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2018-05-17 09:23:28 +0100 |
commit | e057a9186a1ed0a169fcf4db7a2598d08f530834 (patch) | |
tree | ca9a9f8283e50c9676d4c167390ce0a90eb05793 /src/kernels | |
parent | 0cb95800424273e56740d3b23cef53e740eab9b5 (diff) |
First version of direct reading from image tensor for convgemm: only for edge cases now
Diffstat (limited to 'src/kernels')
-rw-r--r-- | src/kernels/level3/xconvgemm.opencl | 46 |
1 files changed, 44 insertions, 2 deletions
diff --git a/src/kernels/level3/xconvgemm.opencl b/src/kernels/level3/xconvgemm.opencl index d3c53d7d..cddb6785 100644 --- a/src/kernels/level3/xconvgemm.opencl +++ b/src/kernels/level3/xconvgemm.opencl @@ -19,15 +19,52 @@ R"( // ================================================================================================= #if defined(ROUTINE_CONVGEMM) +// Loads global off-chip memory into thread-private register files. This function is specific for +// loading the image input tensor. This includes a bounds check. +INLINE_FUNC real GlobalToPrivateCheckedImage(const __global real* restrict imagegm, const int image_offset_batch, + const int h_id, const int w_id, const int kwg, + const int input_h, const int input_w, const int channels, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w) { + real result; + + const int kernel_2d_index = kwg % (kernel_h * kernel_w); + const int kw_id = kernel_2d_index % kernel_w; + const int kh_id = kernel_2d_index / kernel_w; + const int c_id = kwg / (kernel_h * kernel_w); + + const int h_index = -pad_h + kh_id * dilation_h + stride_h * h_id; + const int w_index = -pad_w + kw_id * dilation_w + stride_w * w_id; + if (h_index >= 0 && h_index < input_h && + w_index >= 0 && w_index < input_w) { + const int image_index = w_index + input_w * (h_index + input_h * c_id); + result = imagegm[image_index + image_offset_batch]; + } + else { + SetToZero(result); + } + return result; +} + // ConvGEMM kernel __kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1))) void Xconvgemm(const int num_patches, const int num_kernels, const int patch_size, const __global realMD* restrict colgm, const int col_offset, const int col_stride, const __global realND* restrict kernelgm, const int kernel_offset, - __global real* resultgm, const int result_offset, const int result_stride) { + __global real* resultgm, const int result_offset, const int result_stride, + const int input_h, const int input_w, const int channels, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const __global realMD* restrict imagegm, const int image_offset, + const int output_h, const int output_w) { // Batch offsets const int batch = get_group_id(2); + const int image_offset_batch = image_offset + channels * input_h * input_w * batch; const int col_offset_batch = col_offset + col_stride * batch; const int result_offset_batch = result_offset + result_stride * batch; @@ -59,6 +96,8 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz // processes only the main parts: output blocks of WGD by WGD. const int idm = get_local_id(0) * MWID + GetGroupID0() * WGD; const int idn = get_local_id(1) * NWID + GetGroupID1() * WGD; + const int w_id = idm % output_w; + const int h_id = idm / output_w; if ((idm < (num_patches/WGD)*WGD) && (idn < (num_kernels/WGD)*WGD)) { // Loops over all complete workgroup tiles (K-dimension) @@ -190,7 +229,10 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz // Loads data: off-chip --> private (matrix A and B) #pragma unroll for (int _mi = 0; _mi < MWID; _mi += 1) { - apd[_mi] = GlobalToPrivateCheckedA(colgms, _mi, num_patches, col_offset_batch, idm, kwg, false, false, num_patches); + apd[_mi] = GlobalToPrivateCheckedImage(imagegm, image_offset_batch, h_id, w_id, kwg, + input_h, input_w, channels, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w); } #pragma unroll for (int _ni = 0; _ni < NWID; _ni += 1) { |