summary | refs | log | tree | commit | diff
path: root/src/kernels/level3/xgemm_direct_part3.opencl
diff options
context:
space:
mode:
Diffstat (limited to 'src/kernels/level3/xgemm_direct_part3.opencl')
-rw-r--r-- src/kernels/level3/xgemm_direct_part3.opencl | 18
1 files changed, 13 insertions, 5 deletions
diff --git a/src/kernels/level3/xgemm_direct_part3.opencl b/src/kernels/level3/xgemm_direct_part3.opencl
index 14ed8223..a9350e00 100644
--- a/src/kernels/level3/xgemm_direct_part3.opencl
+++ b/src/kernels/level3/xgemm_direct_part3.opencl
@@ -46,17 +46,25 @@ inline void XgemmDirect(const int kSizeM, const int kSizeN, const int kSizeK,
// processes only the main parts: output blocks of WGD by WGD.
const int idm = get_local_id(0) * MWID + GetGroupID0() * WGD;
const int idn = get_local_id(1) * NWID + GetGroupID1() * WGD;
-
- if ((idm < (kSizeM/WGD)*WGD) && (idn < (kSizeN/WGD)*WGD) &&
- (a_ld % VWMD == 0) && (b_ld % VWND == 0)) {
+ if ((idm < (kSizeM/WGD)*WGD) && (idn < (kSizeN/WGD)*WGD)) {
// Loops over all complete workgroup tiles (K-dimension)
int kwg = 0;
for (; kwg < (kSizeK/WGD) * WGD; kwg+=WGD) {
// Loads data: off-chip --> local (matrix A and B)
- GlobalToLocalDirectA(agm, alm, a_ld, a_offset, kwg, a_transpose, a_conjugate);
- GlobalToLocalDirectB(bgm, blm, b_ld, b_offset, kwg, b_transpose, b_conjugate);
+ if (a_ld % VWMD == 0) {
+ GlobalToLocalDirectA(agm, alm, a_ld, a_offset, kwg, a_transpose, a_conjugate);
+ }
+ else {
+ GlobalToLocalScalarA(agms, alm, a_ld, a_offset, kwg, a_transpose, a_conjugate);
+ }
+ if (b_ld % VWND == 0) {
+ GlobalToLocalDirectB(bgm, blm, b_ld, b_offset, kwg, b_transpose, b_conjugate);
+ }
+ else {
+ GlobalToLocalScalarB(bgms, blm, b_ld, b_offset, kwg, b_transpose, b_conjugate);
+ }
barrier(CLK_LOCAL_MEM_FENCE);
// Loops over all workitem tiles, unrolled by a factor KWID