diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/kernel_preprocessor.cpp | 155 | ||||
-rw-r--r-- | src/kernels/level2/xgemv.opencl | 14 | ||||
-rw-r--r-- | src/kernels/level2/xgemv_fast.opencl | 76 |
3 files changed, 165 insertions, 80 deletions
diff --git a/src/kernel_preprocessor.cpp b/src/kernel_preprocessor.cpp index f0245ce6..6239361b 100644 --- a/src/kernel_preprocessor.cpp +++ b/src/kernel_preprocessor.cpp @@ -18,6 +18,8 @@ // The above also requires the spaces in that exact form // - The loop variable should be a unique string within the code in the for-loop body (e.g. don't // use 'i' or 'w' but rather '_w' or a longer name. +// - The pragma "#pragma promote_to_registers" unrolls an array into multiple scalar values. The +// name of this scalar should be unique (see above). // // ================================================================================================= @@ -176,26 +178,85 @@ bool EvaluateCondition(std::string condition, // Array to register promotion, e.g. arr[w] to {arr_0, arr_1} void ArrayToRegister(std::string &source_line, const std::unordered_map<std::string, int>& defines, - const std::unordered_map<std::string, size_t>& arrays_to_registers) { + const std::unordered_map<std::string, size_t>& arrays_to_registers, + const size_t num_brackets) { for (const auto array_name_map : arrays_to_registers) { // only if marked to be promoted - auto array_pos = source_line.find(array_name_map.first + "["); - while (array_pos != std::string::npos) { - - // Retrieves the array index - const auto loop_remainder = source_line.substr(array_pos); - const auto loop_split = split(split(loop_remainder, '[')[1], ']'); - if (loop_split.size() < 2) { RaiseError(source_line, "Mis-formatted array declaration"); } - auto array_index_string = loop_split[0]; - - // Replaces the array with a register value - SubstituteDefines(defines, array_index_string); - const auto array_index = StringToDigit(array_index_string, source_line); - FindReplace(source_line, array_name_map.first + "[" + loop_split[0] + "]", - array_name_map.first + "_" + ToString(array_index)); - - // Performs an extra substitution if this array occurs another time in this line - array_pos = source_line.find(array_name_map.first + "["); + + // Outside of a function + if (num_brackets == 0) { + + // Case 1: argument in a function declaration (e.g. 'void func(const float arr[2])') + const auto array_pos = source_line.find(array_name_map.first + "["); + if (array_pos != std::string::npos) { + SubstituteDefines(defines, source_line); + + // Finds the full array declaration (e.g. 'const float arr[2]') + const auto left_split = split(source_line, '('); + auto arguments = left_split.size() >= 2 ? left_split[1] : source_line; + const auto right_split = split(arguments, ')'); + arguments = right_split.size() >= 1 ? right_split[0] : arguments; + const auto comma_split = split(arguments, ','); + for (auto j = size_t{0}; j < comma_split.size(); ++j) { + if (comma_split[j].find(array_name_map.first + "[") != std::string::npos) { + + // Retrieves the array index + const auto left_square_split = split(comma_split[j], '['); + if (left_square_split.size() < 2) { RaiseError(source_line, "Mis-formatted array declaration #A"); } + const auto right_square_split = split(left_square_split[1], ']'); + if (right_square_split.size() < 1) { RaiseError(source_line, "Mis-formatted array declaration #B"); } + auto array_index_string = right_square_split[0]; + const auto array_index = StringToDigit(array_index_string, source_line); + + // Creates the new string + auto replacement = std::string{}; + for (auto index = size_t{0}; index < array_index; ++index) { + replacement += left_square_split[0] + "_" + ToString(index); + if (index != array_index - 1) { replacement += ","; } + } + + // Performs the actual replacement + FindReplace(source_line, comma_split[j], replacement); + } + } + } + } + + // Inside a function + else { + auto array_pos = source_line.find(array_name_map.first + "["); + + // Case 2: passed to another function (e.g. 'func(arr)') + if (array_pos == std::string::npos) { // assumes case 2 and case 3 (below) cannot occur in one line + auto bracket_split = split(source_line, '('); + if (bracket_split.size() >= 2) { + auto replacement = std::string{}; + for (auto i = size_t{0}; i < array_name_map.second; ++i) { + replacement += array_name_map.first + "_" + ToString(i); + if (i != array_name_map.second - 1) { replacement += ", "; } + } + FindReplace(source_line, array_name_map.first, replacement); + } + } + + // Case 2: used as an array (e.g. 'arr[w]') + while (array_pos != std::string::npos) { + + // Retrieves the array index + const auto loop_remainder = source_line.substr(array_pos); + const auto loop_split = split(split(loop_remainder, '[')[1], ']'); + if (loop_split.size() < 2) { RaiseError(source_line, "Mis-formatted array declaration #C"); } + auto array_index_string = loop_split[0]; + + // Replaces the array with a register value + SubstituteDefines(defines, array_index_string); + const auto array_index = StringToDigit(array_index_string, source_line); + FindReplace(source_line, array_name_map.first + "[" + loop_split[0] + "]", + array_name_map.first + "_" + ToString(array_index)); + + // Performs an extra substitution if this array occurs another time in this line + array_pos = source_line.find(array_name_map.first + "["); + } } } } @@ -319,31 +380,21 @@ std::vector<std::string> PreprocessDefinesAndComments(const std::string& source, // ================================================================================================= -// Second pass: unroll loops +// Second pass: detect array-to-register promotion pragma's and replace declarations & function calls std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& source_lines, const std::unordered_map<std::string, int>& defines, - std::unordered_map<std::string, size_t>& arrays_to_registers, - const bool array_to_register_promotion) { + std::unordered_map<std::string, size_t>& arrays_to_registers) { auto lines = std::vector<std::string>(); auto brackets = size_t{0}; - auto unroll_next_loop = false; auto promote_next_array_to_registers = false; for (auto line_id = size_t{0}; line_id < source_lines.size(); ++line_id) { auto line = source_lines[line_id]; // Detect #pragma promote_to_registers directives (unofficial pragma) - if (array_to_register_promotion) { - if (line.find("#pragma promote_to_registers") != std::string::npos) { - promote_next_array_to_registers = true; - continue; - } - } - - // Detect #pragma unroll directives - if (line.find("#pragma unroll") != std::string::npos) { - unroll_next_loop = true; + if (line.find("#pragma promote_to_registers") != std::string::npos) { + promote_next_array_to_registers = true; continue; } @@ -369,10 +420,43 @@ std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& s const auto array_name_split = split(line_split1[0], ' '); if (array_name_split.size() < 2) { RaiseError(line, "Mis-formatted array declaration #2"); } const auto array_name = array_name_split[array_name_split.size() - 1]; - arrays_to_registers[array_name] = brackets; // TODO: bracket count not used currently for scope checking + arrays_to_registers[array_name] = array_size; + // TODO: bracket count not used currently for scope checking continue; } + // Regular line + lines.emplace_back(line); + } + return lines; +} + +// ================================================================================================= + +// Third pass: unroll loops and perform actual array-to-register promotion +std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& source_lines, + const std::unordered_map<std::string, int>& defines, + std::unordered_map<std::string, size_t>& arrays_to_registers, + const bool array_to_register_promotion) { + auto lines = std::vector<std::string>(); + + auto brackets = size_t{0}; + auto unroll_next_loop = false; + + for (auto line_id = size_t{0}; line_id < source_lines.size(); ++line_id) { + auto line = source_lines[line_id]; + + // Detect #pragma unroll directives + if (line.find("#pragma unroll") != std::string::npos) { + unroll_next_loop = true; + continue; + } + + // Brackets + const auto num_brackets_before = brackets; + brackets += std::count(line.begin(), line.end(), '{'); + brackets -= std::count(line.begin(), line.end(), '}'); + // Loop unrolling assuming it to be in the form "for (int w = 0; w < 4; w += 1) {" if (unroll_next_loop) { unroll_next_loop = false; @@ -427,7 +511,7 @@ std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& s // Array to register promotion if (array_to_register_promotion) { - ArrayToRegister(loop_line, defines, arrays_to_registers); + ArrayToRegister(loop_line, defines, arrays_to_registers, num_brackets_before); } lines.emplace_back(loop_line); @@ -440,7 +524,7 @@ std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& s // Array to register promotion if (array_to_register_promotion) { - ArrayToRegister(line, defines, arrays_to_registers); + ArrayToRegister(line, defines, arrays_to_registers, num_brackets_before); } lines.emplace_back(line); @@ -459,6 +543,7 @@ std::string PreprocessKernelSource(const std::string& kernel_source) { // Unrolls loops (single level each call) auto arrays_to_registers = std::unordered_map<std::string, size_t>(); + lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers); lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, false); lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, true); diff --git a/src/kernels/level2/xgemv.opencl b/src/kernels/level2/xgemv.opencl index 2a50e8fb..ba29aba6 100644 --- a/src/kernels/level2/xgemv.opencl +++ b/src/kernels/level2/xgemv.opencl @@ -228,10 +228,10 @@ void Xgemv(const int m, const int n, // Initializes the accumulation register #pragma promote_to_registers - real acc[WPT1]; + real acc1[WPT1]; #pragma unroll for (int _w = 0; _w < WPT1; _w += 1) { - SetToZero(acc[_w]); + SetToZero(acc1[_w]); } // Divides the work in a main and tail section @@ -262,7 +262,7 @@ void Xgemv(const int m, const int n, const int k = kwg + kloop + _kunroll; real value = LoadMatrixA(agm, gid, k, a_ld, a_offset, parameter, kl, ku); if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); } - MultiplyAdd(acc[_w], xlm[kloop + _kunroll], value); + MultiplyAdd(acc1[_w], xlm[kloop + _kunroll], value); } } } @@ -273,7 +273,7 @@ void Xgemv(const int m, const int n, const int k = kwg + kloop + _kunroll; real value = LoadMatrixA(agm, k, gid, a_ld, a_offset, parameter, kl, ku); if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); } - MultiplyAdd(acc[_w], xlm[kloop + _kunroll], value); + MultiplyAdd(acc1[_w], xlm[kloop + _kunroll], value); } } } @@ -295,20 +295,20 @@ void Xgemv(const int m, const int n, for (int k=n_floor; k<n; ++k) { real value = LoadMatrixA(agm, gid, k, a_ld, a_offset, parameter, kl, ku); if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); } - MultiplyAdd(acc[_w], xgm[k*x_inc + x_offset], value); + MultiplyAdd(acc1[_w], xgm[k*x_inc + x_offset], value); } } else { // Transposed for (int k=n_floor; k<n; ++k) { real value = LoadMatrixA(agm, k, gid, a_ld, a_offset, parameter, kl, ku); if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); } - MultiplyAdd(acc[_w], xgm[k*x_inc + x_offset], value); + MultiplyAdd(acc1[_w], xgm[k*x_inc + x_offset], value); } } // Stores the final result real yval = ygm[gid*y_inc + y_offset]; - AXPBY(ygm[gid*y_inc + y_offset], alpha, acc[_w], beta, yval); + AXPBY(ygm[gid*y_inc + y_offset], alpha, acc1[_w], beta, yval); } } } diff --git a/src/kernels/level2/xgemv_fast.opencl b/src/kernels/level2/xgemv_fast.opencl index 892bc55c..45ceb36c 100644 --- a/src/kernels/level2/xgemv_fast.opencl +++ b/src/kernels/level2/xgemv_fast.opencl @@ -106,10 +106,10 @@ void XgemvFast(const int m, const int n, // Initializes the accumulation registers #pragma promote_to_registers - real acc[WPT2]; + real acc2[WPT2]; #pragma unroll for (int _w = 0; _w < WPT2; _w += 1) { - SetToZero(acc[_w]); + SetToZero(acc2[_w]); } // Loops over work-group sized portions of the work @@ -131,41 +131,41 @@ void XgemvFast(const int m, const int n, const int gid = (WPT2/VW2)*get_global_id(0) + _w; realVF avec = agm[(a_ld/VW2)*k + gid]; #if VW2 == 1 - MultiplyAdd(acc[VW2*_w+0], xlm[_kl], avec); + MultiplyAdd(acc2[VW2*_w+0], xlm[_kl], avec); #elif VW2 == 2 - MultiplyAdd(acc[VW2*_w+0], xlm[_kl], avec.x); - MultiplyAdd(acc[VW2*_w+1], xlm[_kl], avec.y); + MultiplyAdd(acc2[VW2*_w+0], xlm[_kl], avec.x); + MultiplyAdd(acc2[VW2*_w+1], xlm[_kl], avec.y); #elif VW2 == 4 - MultiplyAdd(acc[VW2*_w+0], xlm[_kl], avec.x); - MultiplyAdd(acc[VW2*_w+1], xlm[_kl], avec.y); - MultiplyAdd(acc[VW2*_w+2], xlm[_kl], avec.z); - MultiplyAdd(acc[VW2*_w+3], xlm[_kl], avec.w); + MultiplyAdd(acc2[VW2*_w+0], xlm[_kl], avec.x); + MultiplyAdd(acc2[VW2*_w+1], xlm[_kl], avec.y); + MultiplyAdd(acc2[VW2*_w+2], xlm[_kl], avec.z); + MultiplyAdd(acc2[VW2*_w+3], xlm[_kl], avec.w); #elif VW2 == 8 - MultiplyAdd(acc[VW2*_w+0], xlm[_kl], avec.s0); - MultiplyAdd(acc[VW2*_w+1], xlm[_kl], avec.s1); - MultiplyAdd(acc[VW2*_w+2], xlm[_kl], avec.s2); - MultiplyAdd(acc[VW2*_w+3], xlm[_kl], avec.s3); - MultiplyAdd(acc[VW2*_w+4], xlm[_kl], avec.s4); - MultiplyAdd(acc[VW2*_w+5], xlm[_kl], avec.s5); - MultiplyAdd(acc[VW2*_w+6], xlm[_kl], avec.s6); - MultiplyAdd(acc[VW2*_w+7], xlm[_kl], avec.s7); + MultiplyAdd(acc2[VW2*_w+0], xlm[_kl], avec.s0); + MultiplyAdd(acc2[VW2*_w+1], xlm[_kl], avec.s1); + MultiplyAdd(acc2[VW2*_w+2], xlm[_kl], avec.s2); + MultiplyAdd(acc2[VW2*_w+3], xlm[_kl], avec.s3); + MultiplyAdd(acc2[VW2*_w+4], xlm[_kl], avec.s4); + MultiplyAdd(acc2[VW2*_w+5], xlm[_kl], avec.s5); + MultiplyAdd(acc2[VW2*_w+6], xlm[_kl], avec.s6); + MultiplyAdd(acc2[VW2*_w+7], xlm[_kl], avec.s7); #elif VW2 == 16 - MultiplyAdd(acc[VW2*_w+0], xlm[_kl], avec.s0); - MultiplyAdd(acc[VW2*_w+1], xlm[_kl], avec.s1); - MultiplyAdd(acc[VW2*_w+2], xlm[_kl], avec.s2); - MultiplyAdd(acc[VW2*_w+3], xlm[_kl], avec.s3); - MultiplyAdd(acc[VW2*_w+4], xlm[_kl], avec.s4); - MultiplyAdd(acc[VW2*_w+5], xlm[_kl], avec.s5); - MultiplyAdd(acc[VW2*_w+6], xlm[_kl], avec.s6); - MultiplyAdd(acc[VW2*_w+7], xlm[_kl], avec.s7); - MultiplyAdd(acc[VW2*_w+8], xlm[_kl], avec.s8); - MultiplyAdd(acc[VW2*_w+9], xlm[_kl], avec.s9); - MultiplyAdd(acc[VW2*_w+10], xlm[_kl], avec.sA); - MultiplyAdd(acc[VW2*_w+11], xlm[_kl], avec.sB); - MultiplyAdd(acc[VW2*_w+12], xlm[_kl], avec.sC); - MultiplyAdd(acc[VW2*_w+13], xlm[_kl], avec.sD); - MultiplyAdd(acc[VW2*_w+14], xlm[_kl], avec.sE); - MultiplyAdd(acc[VW2*_w+15], xlm[_kl], avec.sF); + MultiplyAdd(acc2[VW2*_w+0], xlm[_kl], avec.s0); + MultiplyAdd(acc2[VW2*_w+1], xlm[_kl], avec.s1); + MultiplyAdd(acc2[VW2*_w+2], xlm[_kl], avec.s2); + MultiplyAdd(acc2[VW2*_w+3], xlm[_kl], avec.s3); + MultiplyAdd(acc2[VW2*_w+4], xlm[_kl], avec.s4); + MultiplyAdd(acc2[VW2*_w+5], xlm[_kl], avec.s5); + MultiplyAdd(acc2[VW2*_w+6], xlm[_kl], avec.s6); + MultiplyAdd(acc2[VW2*_w+7], xlm[_kl], avec.s7); + MultiplyAdd(acc2[VW2*_w+8], xlm[_kl], avec.s8); + MultiplyAdd(acc2[VW2*_w+9], xlm[_kl], avec.s9); + MultiplyAdd(acc2[VW2*_w+10], xlm[_kl], avec.sA); + MultiplyAdd(acc2[VW2*_w+11], xlm[_kl], avec.sB); + MultiplyAdd(acc2[VW2*_w+12], xlm[_kl], avec.sC); + MultiplyAdd(acc2[VW2*_w+13], xlm[_kl], avec.sD); + MultiplyAdd(acc2[VW2*_w+14], xlm[_kl], avec.sE); + MultiplyAdd(acc2[VW2*_w+15], xlm[_kl], avec.sF); #endif } } @@ -179,7 +179,7 @@ void XgemvFast(const int m, const int n, for (int _w = 0; _w < WPT2; _w += 1) { const int gid = WPT2*get_global_id(0) + _w; real yval = ygm[gid*y_inc + y_offset]; - AXPBY(ygm[gid*y_inc + y_offset], alpha, acc[_w], beta, yval); + AXPBY(ygm[gid*y_inc + y_offset], alpha, acc2[_w], beta, yval); } } @@ -214,8 +214,8 @@ void XgemvFastRot(const int m, const int n, __local real xlm[WPT3]; // Initializes the accumulation register - real acc; - SetToZero(acc); + real acc3; + SetToZero(acc3); // Loops over tile-sized portions of the work for (int kwg=0; kwg<n; kwg+=WPT3) { @@ -280,7 +280,7 @@ void XgemvFastRot(const int m, const int n, for (int _v = 0; _v < VW3; _v += 1) { real aval = tile[lid_mod*VW3 + _v][lid_div * (WPT3/VW3) + _kl]; real xval = xlm[_kl*VW3 + _v]; - MultiplyAdd(acc, xval, aval); + MultiplyAdd(acc3, xval, aval); } } @@ -291,7 +291,7 @@ void XgemvFastRot(const int m, const int n, // Stores the final result const int gid = get_global_id(0); real yval = ygm[gid * y_inc + y_offset]; - AXPBY(ygm[gid * y_inc + y_offset], alpha, acc, beta, yval); + AXPBY(ygm[gid * y_inc + y_offset], alpha, acc3, beta, yval); } // ================================================================================================= |