diff options
Diffstat (limited to 'src/kernel_preprocessor.cpp')
-rw-r--r-- | src/kernel_preprocessor.cpp | 155 |
1 files changed, 120 insertions, 35 deletions
diff --git a/src/kernel_preprocessor.cpp b/src/kernel_preprocessor.cpp index f0245ce6..6239361b 100644 --- a/src/kernel_preprocessor.cpp +++ b/src/kernel_preprocessor.cpp @@ -18,6 +18,8 @@ // The above also requires the spaces in that exact form // - The loop variable should be a unique string within the code in the for-loop body (e.g. don't // use 'i' or 'w' but rather '_w' or a longer name. +// - The pragma "#pragma promote_to_registers" unrolls an array into multiple scalar values. The +// name of this scalar should be unique (see above). // // ================================================================================================= @@ -176,26 +178,85 @@ bool EvaluateCondition(std::string condition, // Array to register promotion, e.g. arr[w] to {arr_0, arr_1} void ArrayToRegister(std::string &source_line, const std::unordered_map<std::string, int>& defines, - const std::unordered_map<std::string, size_t>& arrays_to_registers) { + const std::unordered_map<std::string, size_t>& arrays_to_registers, + const size_t num_brackets) { for (const auto array_name_map : arrays_to_registers) { // only if marked to be promoted - auto array_pos = source_line.find(array_name_map.first + "["); - while (array_pos != std::string::npos) { - - // Retrieves the array index - const auto loop_remainder = source_line.substr(array_pos); - const auto loop_split = split(split(loop_remainder, '[')[1], ']'); - if (loop_split.size() < 2) { RaiseError(source_line, "Mis-formatted array declaration"); } - auto array_index_string = loop_split[0]; - - // Replaces the array with a register value - SubstituteDefines(defines, array_index_string); - const auto array_index = StringToDigit(array_index_string, source_line); - FindReplace(source_line, array_name_map.first + "[" + loop_split[0] + "]", - array_name_map.first + "_" + ToString(array_index)); - - // Performs an extra substitution if this array occurs another time in this line - array_pos = source_line.find(array_name_map.first + "["); + + // Outside of a function + if (num_brackets == 0) { + + // Case 1: argument in a function declaration (e.g. 'void func(const float arr[2])') + const auto array_pos = source_line.find(array_name_map.first + "["); + if (array_pos != std::string::npos) { + SubstituteDefines(defines, source_line); + + // Finds the full array declaration (e.g. 'const float arr[2]') + const auto left_split = split(source_line, '('); + auto arguments = left_split.size() >= 2 ? left_split[1] : source_line; + const auto right_split = split(arguments, ')'); + arguments = right_split.size() >= 1 ? right_split[0] : arguments; + const auto comma_split = split(arguments, ','); + for (auto j = size_t{0}; j < comma_split.size(); ++j) { + if (comma_split[j].find(array_name_map.first + "[") != std::string::npos) { + + // Retrieves the array index + const auto left_square_split = split(comma_split[j], '['); + if (left_square_split.size() < 2) { RaiseError(source_line, "Mis-formatted array declaration #A"); } + const auto right_square_split = split(left_square_split[1], ']'); + if (right_square_split.size() < 1) { RaiseError(source_line, "Mis-formatted array declaration #B"); } + auto array_index_string = right_square_split[0]; + const auto array_index = StringToDigit(array_index_string, source_line); + + // Creates the new string + auto replacement = std::string{}; + for (auto index = size_t{0}; index < array_index; ++index) { + replacement += left_square_split[0] + "_" + ToString(index); + if (index != array_index - 1) { replacement += ","; } + } + + // Performs the actual replacement + FindReplace(source_line, comma_split[j], replacement); + } + } + } + } + + // Inside a function + else { + auto array_pos = source_line.find(array_name_map.first + "["); + + // Case 2: passed to another function (e.g. 'func(arr)') + if (array_pos == std::string::npos) { // assumes case 2 and case 3 (below) cannot occur in one line + auto bracket_split = split(source_line, '('); + if (bracket_split.size() >= 2) { + auto replacement = std::string{}; + for (auto i = size_t{0}; i < array_name_map.second; ++i) { + replacement += array_name_map.first + "_" + ToString(i); + if (i != array_name_map.second - 1) { replacement += ", "; } + } + FindReplace(source_line, array_name_map.first, replacement); + } + } + + // Case 2: used as an array (e.g. 'arr[w]') + while (array_pos != std::string::npos) { + + // Retrieves the array index + const auto loop_remainder = source_line.substr(array_pos); + const auto loop_split = split(split(loop_remainder, '[')[1], ']'); + if (loop_split.size() < 2) { RaiseError(source_line, "Mis-formatted array declaration #C"); } + auto array_index_string = loop_split[0]; + + // Replaces the array with a register value + SubstituteDefines(defines, array_index_string); + const auto array_index = StringToDigit(array_index_string, source_line); + FindReplace(source_line, array_name_map.first + "[" + loop_split[0] + "]", + array_name_map.first + "_" + ToString(array_index)); + + // Performs an extra substitution if this array occurs another time in this line + array_pos = source_line.find(array_name_map.first + "["); + } } } } @@ -319,31 +380,21 @@ std::vector<std::string> PreprocessDefinesAndComments(const std::string& source, // ================================================================================================= -// Second pass: unroll loops +// Second pass: detect array-to-register promotion pragma's and replace declarations & function calls std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& source_lines, const std::unordered_map<std::string, int>& defines, - std::unordered_map<std::string, size_t>& arrays_to_registers, - const bool array_to_register_promotion) { + std::unordered_map<std::string, size_t>& arrays_to_registers) { auto lines = std::vector<std::string>(); auto brackets = size_t{0}; - auto unroll_next_loop = false; auto promote_next_array_to_registers = false; for (auto line_id = size_t{0}; line_id < source_lines.size(); ++line_id) { auto line = source_lines[line_id]; // Detect #pragma promote_to_registers directives (unofficial pragma) - if (array_to_register_promotion) { - if (line.find("#pragma promote_to_registers") != std::string::npos) { - promote_next_array_to_registers = true; - continue; - } - } - - // Detect #pragma unroll directives - if (line.find("#pragma unroll") != std::string::npos) { - unroll_next_loop = true; + if (line.find("#pragma promote_to_registers") != std::string::npos) { + promote_next_array_to_registers = true; continue; } @@ -369,10 +420,43 @@ std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& s const auto array_name_split = split(line_split1[0], ' '); if (array_name_split.size() < 2) { RaiseError(line, "Mis-formatted array declaration #2"); } const auto array_name = array_name_split[array_name_split.size() - 1]; - arrays_to_registers[array_name] = brackets; // TODO: bracket count not used currently for scope checking + arrays_to_registers[array_name] = array_size; + // TODO: bracket count not used currently for scope checking continue; } + // Regular line + lines.emplace_back(line); + } + return lines; +} + +// ================================================================================================= + +// Third pass: unroll loops and perform actual array-to-register promotion +std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& source_lines, + const std::unordered_map<std::string, int>& defines, + std::unordered_map<std::string, size_t>& arrays_to_registers, + const bool array_to_register_promotion) { + auto lines = std::vector<std::string>(); + + auto brackets = size_t{0}; + auto unroll_next_loop = false; + + for (auto line_id = size_t{0}; line_id < source_lines.size(); ++line_id) { + auto line = source_lines[line_id]; + + // Detect #pragma unroll directives + if (line.find("#pragma unroll") != std::string::npos) { + unroll_next_loop = true; + continue; + } + + // Brackets + const auto num_brackets_before = brackets; + brackets += std::count(line.begin(), line.end(), '{'); + brackets -= std::count(line.begin(), line.end(), '}'); + // Loop unrolling assuming it to be in the form "for (int w = 0; w < 4; w += 1) {" if (unroll_next_loop) { unroll_next_loop = false; @@ -427,7 +511,7 @@ std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& s // Array to register promotion if (array_to_register_promotion) { - ArrayToRegister(loop_line, defines, arrays_to_registers); + ArrayToRegister(loop_line, defines, arrays_to_registers, num_brackets_before); } lines.emplace_back(loop_line); @@ -440,7 +524,7 @@ std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& s // Array to register promotion if (array_to_register_promotion) { - ArrayToRegister(line, defines, arrays_to_registers); + ArrayToRegister(line, defines, arrays_to_registers, num_brackets_before); } lines.emplace_back(line); @@ -459,6 +543,7 @@ std::string PreprocessKernelSource(const std::string& kernel_source) { // Unrolls loops (single level each call) auto arrays_to_registers = std::unordered_map<std::string, size_t>(); + lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers); lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, false); lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, true); |