From 0a1a3de58a410f61f3b990537541a633826ea640 Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Sun, 3 Dec 2017 16:39:22 +0100 Subject: Added basic bracket parsing in defines and loop expressions --- src/kernel_preprocessor.cpp | 60 +++++++++++++++++++++++++++++++++++++-------- src/kernels/common.opencl | 34 ++++++++++++------------- 2 files changed, 67 insertions(+), 27 deletions(-) (limited to 'src') diff --git a/src/kernel_preprocessor.cpp b/src/kernel_preprocessor.cpp index 9f79d40f..f0245ce6 100644 --- a/src/kernel_preprocessor.cpp +++ b/src/kernel_preprocessor.cpp @@ -44,30 +44,68 @@ bool HasOnlyDigits(const std::string& str) { return str.find_first_not_of(" 0123456789") == std::string::npos; } -// Converts a string to an integer. The source line is printed in case an exception is raised. -size_t StringToDigit(const std::string& str, const std::string& source_line) { +// Simple unsigned integer math parser +int ParseMath(const std::string& str) { + + // Handles brackets + const auto split_close = split(str, ')'); + if (split_close.size() >= 2) { + const auto split_end = split(split_close[0], '('); + if (split_end.size() < 2) { RaiseError(str, "Mismatching brackets #0"); } + const auto bracket_contents = ParseMath(split_end[split_end.size() - 1]); + auto before = std::string{}; + for (auto i = size_t{0}; i < split_end.size() - 1; ++i) { + before += split_end[i]; + if (i != split_end.size() - 2) { before += "("; } + } + auto after = std::string{}; + for (auto i = size_t{1}; i < split_close.size(); ++i) { + after += split_close[i]; + if (i != split_close.size() - 1) { after += ")"; } + } + return ParseMath(before + ToString(bracket_contents) + after); + } // Handles addition const auto split_add = split(str, '+'); if (split_add.size() == 2) { - return StringToDigit(split_add[0], source_line) + StringToDigit(split_add[1], source_line); + const auto lhs = ParseMath(split_add[0]); + const auto rhs = ParseMath(split_add[1]); + if (lhs == -1 || rhs == -1) { return -1; } + return lhs + rhs; } // Handles multiplication const auto split_mul = split(str, '*'); if (split_mul.size() == 2) { - return StringToDigit(split_mul[0], source_line) * StringToDigit(split_mul[1], source_line); + const auto lhs = ParseMath(split_mul[0]); + const auto rhs = ParseMath(split_mul[1]); + if (lhs == -1 || rhs == -1) { return -1; } + return lhs * rhs; } // Handles division const auto split_div = split(str, '/'); if (split_div.size() == 2) { - return StringToDigit(split_div[0], source_line) / StringToDigit(split_div[1], source_line); + const auto lhs = ParseMath(split_div[0]); + const auto rhs = ParseMath(split_div[1]); + if (lhs == -1 || rhs == -1) { return -1; } + return lhs / rhs; } // Handles the digits - if (not HasOnlyDigits(str)) { RaiseError(source_line, "Not a digit: " + str); } - return static_cast(std::stoi(str)); + if (HasOnlyDigits(str)) { + return std::stoi(str); + } + return -1; // error value +} + + +// Converts a string to an integer. The source line is printed in case an exception is raised. +size_t StringToDigit(const std::string& str, const std::string& source_line) { + const auto result = ParseMath(str); + if (result == -1) { RaiseError(source_line, "Not a digit: " + str); } + return static_cast(result); } @@ -226,10 +264,12 @@ std::vector PreprocessDefinesAndComments(const std::string& source, if (define_pos != std::string::npos) { const auto define = line.substr(define_pos + 8); // length of "#define " const auto value_pos = define.find(" "); - const auto value = define.substr(value_pos + 1); + auto value = define.substr(value_pos + 1); const auto name = define.substr(0, value_pos); - if (HasOnlyDigits(value)) { - defines_int.emplace(name, std::stoi(value)); + SubstituteDefines(defines_int, value); + const auto value_int = ParseMath(value); + if (value_int != -1) { + defines_int.emplace(name, static_cast(value_int)); } defines_string.emplace(name, value); } diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl index 01c411bc..4a476a8b 100644 --- a/src/kernels/common.opencl +++ b/src/kernels/common.opencl @@ -176,61 +176,61 @@ R"( // Adds two complex variables #if PRECISION == 3232 || PRECISION == 6464 - #define Add(c, a, b) c.x = a.x + b.x; c.y = a.y + b.y + #define Add(c,a,b) c.x = a.x + b.x; c.y = a.y + b.y #else - #define Add(c, a, b) c = a + b + #define Add(c,a,b) c = a + b #endif // Subtracts two complex variables #if PRECISION == 3232 || PRECISION == 6464 - #define Subtract(c, a, b) c.x = a.x - b.x; c.y = a.y - b.y + #define Subtract(c,a,b) c.x = a.x - b.x; c.y = a.y - b.y #else - #define Subtract(c, a, b) c = a - b + #define Subtract(c,a,b) c = a - b #endif // Multiply two complex variables (used in the defines below) #if PRECISION == 3232 || PRECISION == 6464 - #define MulReal(a, b) a.x*b.x - a.y*b.y - #define MulImag(a, b) a.x*b.y + a.y*b.x + #define MulReal(a,b) a.x*b.x - a.y*b.y + #define MulImag(a,b) a.x*b.y + a.y*b.x #endif // The scalar multiply function #if PRECISION == 3232 || PRECISION == 6464 - #define Multiply(c, a, b) c.x = MulReal(a,b); c.y = MulImag(a,b) + #define Multiply(c,a,b) c.x = MulReal(a,b); c.y = MulImag(a,b) #else - #define Multiply(c, a, b) c = a * b + #define Multiply(c,a,b) c = a * b #endif // The scalar multiply-add function #if PRECISION == 3232 || PRECISION == 6464 - #define MultiplyAdd(c, a, b) c.x += MulReal(a,b); c.y += MulImag(a,b) + #define MultiplyAdd(c,a,b) c.x += MulReal(a,b); c.y += MulImag(a,b) #else #if USE_CL_MAD == 1 - #define MultiplyAdd(c, a, b) c = mad(a, b, c) + #define MultiplyAdd(c,a,b) c = mad(a, b, c) #else - #define MultiplyAdd(c, a, b) c += a * b + #define MultiplyAdd(c,a,b) c += a * b #endif #endif // The scalar multiply-subtract function #if PRECISION == 3232 || PRECISION == 6464 - #define MultiplySubtract(c, a, b) c.x -= MulReal(a,b); c.y -= MulImag(a,b) + #define MultiplySubtract(c,a,b) c.x -= MulReal(a,b); c.y -= MulImag(a,b) #else - #define MultiplySubtract(c, a, b) c -= a * b + #define MultiplySubtract(c,a,b) c -= a * b #endif // The scalar division function: full division #if PRECISION == 3232 || PRECISION == 6464 - #define DivideFull(c, a, b) singlereal num_x = (a.x * b.x) + (a.y * b.y); singlereal num_y = (a.y * b.x) - (a.x * b.y); singlereal denom = (b.x * b.x) + (b.y * b.y); c.x = num_x / denom; c.y = num_y / denom + #define DivideFull(c,a,b) singlereal num_x = (a.x * b.x) + (a.y * b.y); singlereal num_y = (a.y * b.x) - (a.x * b.y); singlereal denom = (b.x * b.x) + (b.y * b.y); c.x = num_x / denom; c.y = num_y / denom #else - #define DivideFull(c, a, b) c = a / b + #define DivideFull(c,a,b) c = a / b #endif // The scalar AXPBY function #if PRECISION == 3232 || PRECISION == 6464 - #define AXPBY(e, a, b, c, d) e.x = MulReal(a,b) + MulReal(c,d); e.y = MulImag(a,b) + MulImag(c,d) + #define AXPBY(e,a,b,c,d) e.x = MulReal(a,b) + MulReal(c,d); e.y = MulImag(a,b) + MulImag(c,d) #else - #define AXPBY(e, a, b, c, d) e = a*b + c*d + #define AXPBY(e,a,b,c,d) e = a*b + c*d #endif // The complex conjugate operation for complex transforms -- cgit v1.2.3