From 14047861ce43593f9a54253b179c305c02e46fa8 Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Tue, 28 Nov 2017 20:52:08 +0100 Subject: Improved the kernel pre-processor in various ways --- src/kernel_preprocessor.cpp | 98 ++++++++++++++++++++++++++++++++------------- 1 file changed, 70 insertions(+), 28 deletions(-) diff --git a/src/kernel_preprocessor.cpp b/src/kernel_preprocessor.cpp index 3373123f..0bb06d55 100644 --- a/src/kernel_preprocessor.cpp +++ b/src/kernel_preprocessor.cpp @@ -32,6 +32,29 @@ namespace clblast { // ================================================================================================= +void RaiseError(const std::string& source_line, const std::string& exception_message) { + printf("Error in source line: %s\n", source_line.c_str()); + throw Error(exception_message); +} + +// ================================================================================================= + +bool HasOnlyDigits(const std::string& str) { + return str.find_first_not_of("0123456789") == std::string::npos; +} + +size_t StringToDigit(const std::string& str, const std::string& source_line) { + const auto split_dividers = split(str, '/'); + if (split_dividers.size() == 2) { + return StringToDigit(split_dividers[0], source_line) / StringToDigit(split_dividers[1], source_line); + } + if (not HasOnlyDigits(str)) { RaiseError(source_line, "Not a digit: " + str); } + return static_cast(std::stoi(str)); +} + + +// ================================================================================================= + void FindReplace(std::string &subject, const std::string &search, const std::string &replace) { auto pos = size_t{0}; @@ -41,13 +64,18 @@ void FindReplace(std::string &subject, const std::string &search, const std::str } } +void SubstituteDefines(const std::unordered_map& defines, + std::string& source_string) { + for (const auto &define : defines) { + FindReplace(source_string, define.first, std::to_string(define.second)); + } +} + bool EvaluateCondition(std::string condition, const std::unordered_map &defines) { // Replace macros in the string - for (const auto &define : defines) { - FindReplace(condition, define.first, std::to_string(define.second)); - } + SubstituteDefines(defines, condition); // Process the equality sign const auto equal_pos = condition.find(" == "); @@ -68,17 +96,32 @@ std::vector PreprocessDefinesAndComments(const std::string& source, // Parse the input string into a vector of lines auto disabled = false; + auto depth = 0; auto source_stream = std::stringstream(source); auto line = std::string{""}; while (std::getline(source_stream, line)) { // Decide whether or not to remain in 'disabled' mode - if (line.find("#endif") != std::string::npos || - line.find("#elif") != std::string::npos) { - disabled = false; + if (line.find("#endif") != std::string::npos) { + if (depth == 1) { + disabled = false; + } + depth--; + } + if (depth == 1) { + if (line.find("#elif") != std::string::npos) { + disabled = false; + } + if (line.find("#else") != std::string::npos) { + disabled = !disabled; + } } - if (line.find("#else") != std::string::npos) { - disabled = !disabled; + + // Measures the depth of pre-processor defines + if ((line.find("#ifndef ") != std::string::npos) || + (line.find("#ifdef ") != std::string::npos) || + (line.find("#if ") != std::string::npos)) { + depth++; } // Not in a disabled-block @@ -150,13 +193,6 @@ std::vector PreprocessDefinesAndComments(const std::string& source, // ================================================================================================= -inline void SubstituteDefines(const std::unordered_map& defines, - std::string& source_string) { - if (defines.count(source_string) == 1) { - source_string = ToString(defines.at(source_string)); - } -} - // Second pass: unroll loops std::vector PreprocessUnrollLoops(const std::vector& source_lines, const std::unordered_map& defines, @@ -193,19 +229,19 @@ std::vector PreprocessUnrollLoops(const std::vector& s if (promote_next_array_to_registers) { promote_next_array_to_registers = false; const auto line_split1 = split(line, '['); - if (line_split1.size() != 2) { throw Error("Mis-formatted array declaration #0"); } + if (line_split1.size() != 2) { RaiseError(line, "Mis-formatted array declaration #0"); } const auto line_split2 = split(line_split1[1], ']'); - if (line_split2.size() != 2) { throw Error("Mis-formatted array declaration #1"); } + if (line_split2.size() != 2) { RaiseError(line, "Mis-formatted array declaration #1"); } auto array_size_string = line_split2[0]; SubstituteDefines(defines, array_size_string); - const auto array_size = std::stoi(array_size_string); - for (auto loop_iter = 0; loop_iter < array_size; ++loop_iter) { + const auto array_size = StringToDigit(array_size_string, line); + for (auto loop_iter = size_t{0}; loop_iter < array_size; ++loop_iter) { lines.emplace_back(line_split1[0] + "_" + ToString(loop_iter) + line_split2[1]); } // Stores the array name const auto array_name_split = split(line_split1[0], ' '); - if (array_name_split.size() < 2) { throw Error("Mis-formatted array declaration #2"); } + if (array_name_split.size() < 2) { RaiseError(line, "Mis-formatted array declaration #2"); } const auto array_name = array_name_split[array_name_split.size() - 1]; arrays_to_registers[array_name] = brackets; // TODO: bracket count not used currently for scope checking continue; @@ -218,16 +254,16 @@ std::vector PreprocessUnrollLoops(const std::vector& s // Parses loop structure const auto for_pos = line.find("for ("); - if (for_pos == std::string::npos) { throw Error("Mis-formatted for-loop #0"); } + if (for_pos == std::string::npos) { RaiseError(line, "Mis-formatted for-loop #0"); } const auto remainder = line.substr(for_pos + 5); // length of "for (" const auto line_split = split(remainder, ' '); - if (line_split.size() != 11) { throw Error("Mis-formatted for-loop #1"); } + if (line_split.size() != 11) { RaiseError(line, "Mis-formatted for-loop #1"); } // Retrieves loop information (and checks for assumptions) const auto variable_type = line_split[0]; const auto variable_name = line_split[1]; - if (variable_name != line_split[4]) { throw Error("Mis-formatted for-loop #2"); } - if (variable_name != line_split[7]) { throw Error("Mis-formatted for-loop #3"); } + if (variable_name != line_split[4]) { RaiseError(line, "Mis-formatted for-loop #2"); } + if (variable_name != line_split[7]) { RaiseError(line, "Mis-formatted for-loop #3"); } auto loop_start_string = line_split[3]; auto loop_end_string = line_split[6]; auto loop_increment_string = line_split[9]; @@ -239,9 +275,9 @@ std::vector PreprocessUnrollLoops(const std::vector& s SubstituteDefines(defines, loop_start_string); SubstituteDefines(defines, loop_end_string); SubstituteDefines(defines, loop_increment_string); - const auto loop_start = std::stoi(loop_start_string); - const auto loop_end = std::stoi(loop_end_string); - const auto loop_increment = std::stoi(loop_increment_string); + const auto loop_start = StringToDigit(loop_start_string, line); + const auto loop_end = StringToDigit(loop_end_string, line); + const auto loop_increment = StringToDigit(loop_increment_string, line); auto indent = std::string{""}; for (auto i = size_t{0}; i < for_pos; ++i) { indent += " "; } @@ -264,7 +300,6 @@ std::vector PreprocessUnrollLoops(const std::vector& s // Array to register promotion, e.g. arr[w] to {arr_0, arr_1} if (array_to_register_promotion) { for (const auto array_name_map : arrays_to_registers) { // only if marked to be promoted - printf("%s: %s\n", loop_line.c_str(), array_name_map.first.c_str()); FindReplace(loop_line, array_name_map.first + "[" + variable_name + "]", array_name_map.first + "_" + ToString(loop_iter)); } @@ -303,6 +338,13 @@ std::string PreprocessKernelSource(const std::string& kernel_source) { for (const auto& line : lines) { processed_kernel += line + "\n"; } + + // Debugging + if (false) { + for (auto i = size_t{0}; i < lines.size(); ++i) { + printf("[%zu] %s\n", i, lines[i].c_str()); + } + } return processed_kernel; } -- cgit v1.2.3