summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-12-03 16:39:22 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2017-12-03 16:39:22 +0100
commit0a1a3de58a410f61f3b990537541a633826ea640 (patch)
tree445cbb7d6ad576324753b5b4f3cc613cb14aa0f7 /src
parent60312e5878fd45225158dd8545a01366f937a871 (diff)
Added basic bracket parsing in defines and loop expressions
Diffstat (limited to 'src')
-rw-r--r--src/kernel_preprocessor.cpp60
-rw-r--r--src/kernels/common.opencl34
2 files changed, 67 insertions, 27 deletions
diff --git a/src/kernel_preprocessor.cpp b/src/kernel_preprocessor.cpp
index 9f79d40f..f0245ce6 100644
--- a/src/kernel_preprocessor.cpp
+++ b/src/kernel_preprocessor.cpp
@@ -44,30 +44,68 @@ bool HasOnlyDigits(const std::string& str) {
return str.find_first_not_of(" 0123456789") == std::string::npos;
}
-// Converts a string to an integer. The source line is printed in case an exception is raised.
-size_t StringToDigit(const std::string& str, const std::string& source_line) {
+// Simple unsigned integer math parser
+int ParseMath(const std::string& str) {
+
+ // Handles brackets
+ const auto split_close = split(str, ')');
+ if (split_close.size() >= 2) {
+ const auto split_end = split(split_close[0], '(');
+ if (split_end.size() < 2) { RaiseError(str, "Mismatching brackets #0"); }
+ const auto bracket_contents = ParseMath(split_end[split_end.size() - 1]);
+ auto before = std::string{};
+ for (auto i = size_t{0}; i < split_end.size() - 1; ++i) {
+ before += split_end[i];
+ if (i != split_end.size() - 2) { before += "("; }
+ }
+ auto after = std::string{};
+ for (auto i = size_t{1}; i < split_close.size(); ++i) {
+ after += split_close[i];
+ if (i != split_close.size() - 1) { after += ")"; }
+ }
+ return ParseMath(before + ToString(bracket_contents) + after);
+ }
// Handles addition
const auto split_add = split(str, '+');
if (split_add.size() == 2) {
- return StringToDigit(split_add[0], source_line) + StringToDigit(split_add[1], source_line);
+ const auto lhs = ParseMath(split_add[0]);
+ const auto rhs = ParseMath(split_add[1]);
+ if (lhs == -1 || rhs == -1) { return -1; }
+ return lhs + rhs;
}
// Handles multiplication
const auto split_mul = split(str, '*');
if (split_mul.size() == 2) {
- return StringToDigit(split_mul[0], source_line) * StringToDigit(split_mul[1], source_line);
+ const auto lhs = ParseMath(split_mul[0]);
+ const auto rhs = ParseMath(split_mul[1]);
+ if (lhs == -1 || rhs == -1) { return -1; }
+ return lhs * rhs;
}
// Handles division
const auto split_div = split(str, '/');
if (split_div.size() == 2) {
- return StringToDigit(split_div[0], source_line) / StringToDigit(split_div[1], source_line);
+ const auto lhs = ParseMath(split_div[0]);
+ const auto rhs = ParseMath(split_div[1]);
+ if (lhs == -1 || rhs == -1) { return -1; }
+ return lhs / rhs;
}
// Handles the digits
- if (not HasOnlyDigits(str)) { RaiseError(source_line, "Not a digit: " + str); }
- return static_cast<size_t>(std::stoi(str));
+ if (HasOnlyDigits(str)) {
+ return std::stoi(str);
+ }
+ return -1; // error value
+}
+
+
+// Converts a string to an integer. The source line is printed in case an exception is raised.
+size_t StringToDigit(const std::string& str, const std::string& source_line) {
+ const auto result = ParseMath(str);
+ if (result == -1) { RaiseError(source_line, "Not a digit: " + str); }
+ return static_cast<size_t>(result);
}
@@ -226,10 +264,12 @@ std::vector<std::string> PreprocessDefinesAndComments(const std::string& source,
if (define_pos != std::string::npos) {
const auto define = line.substr(define_pos + 8); // length of "#define "
const auto value_pos = define.find(" ");
- const auto value = define.substr(value_pos + 1);
+ auto value = define.substr(value_pos + 1);
const auto name = define.substr(0, value_pos);
- if (HasOnlyDigits(value)) {
- defines_int.emplace(name, std::stoi(value));
+ SubstituteDefines(defines_int, value);
+ const auto value_int = ParseMath(value);
+ if (value_int != -1) {
+ defines_int.emplace(name, static_cast<size_t>(value_int));
}
defines_string.emplace(name, value);
}
diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl
index 01c411bc..4a476a8b 100644
--- a/src/kernels/common.opencl
+++ b/src/kernels/common.opencl
@@ -176,61 +176,61 @@ R"(
// Adds two complex variables
#if PRECISION == 3232 || PRECISION == 6464
- #define Add(c, a, b) c.x = a.x + b.x; c.y = a.y + b.y
+ #define Add(c,a,b) c.x = a.x + b.x; c.y = a.y + b.y
#else
- #define Add(c, a, b) c = a + b
+ #define Add(c,a,b) c = a + b
#endif
// Subtracts two complex variables
#if PRECISION == 3232 || PRECISION == 6464
- #define Subtract(c, a, b) c.x = a.x - b.x; c.y = a.y - b.y
+ #define Subtract(c,a,b) c.x = a.x - b.x; c.y = a.y - b.y
#else
- #define Subtract(c, a, b) c = a - b
+ #define Subtract(c,a,b) c = a - b
#endif
// Multiply two complex variables (used in the defines below)
#if PRECISION == 3232 || PRECISION == 6464
- #define MulReal(a, b) a.x*b.x - a.y*b.y
- #define MulImag(a, b) a.x*b.y + a.y*b.x
+ #define MulReal(a,b) a.x*b.x - a.y*b.y
+ #define MulImag(a,b) a.x*b.y + a.y*b.x
#endif
// The scalar multiply function
#if PRECISION == 3232 || PRECISION == 6464
- #define Multiply(c, a, b) c.x = MulReal(a,b); c.y = MulImag(a,b)
+ #define Multiply(c,a,b) c.x = MulReal(a,b); c.y = MulImag(a,b)
#else
- #define Multiply(c, a, b) c = a * b
+ #define Multiply(c,a,b) c = a * b
#endif
// The scalar multiply-add function
#if PRECISION == 3232 || PRECISION == 6464
- #define MultiplyAdd(c, a, b) c.x += MulReal(a,b); c.y += MulImag(a,b)
+ #define MultiplyAdd(c,a,b) c.x += MulReal(a,b); c.y += MulImag(a,b)
#else
#if USE_CL_MAD == 1
- #define MultiplyAdd(c, a, b) c = mad(a, b, c)
+ #define MultiplyAdd(c,a,b) c = mad(a, b, c)
#else
- #define MultiplyAdd(c, a, b) c += a * b
+ #define MultiplyAdd(c,a,b) c += a * b
#endif
#endif
// The scalar multiply-subtract function
#if PRECISION == 3232 || PRECISION == 6464
- #define MultiplySubtract(c, a, b) c.x -= MulReal(a,b); c.y -= MulImag(a,b)
+ #define MultiplySubtract(c,a,b) c.x -= MulReal(a,b); c.y -= MulImag(a,b)
#else
- #define MultiplySubtract(c, a, b) c -= a * b
+ #define MultiplySubtract(c,a,b) c -= a * b
#endif
// The scalar division function: full division
#if PRECISION == 3232 || PRECISION == 6464
- #define DivideFull(c, a, b) singlereal num_x = (a.x * b.x) + (a.y * b.y); singlereal num_y = (a.y * b.x) - (a.x * b.y); singlereal denom = (b.x * b.x) + (b.y * b.y); c.x = num_x / denom; c.y = num_y / denom
+ #define DivideFull(c,a,b) singlereal num_x = (a.x * b.x) + (a.y * b.y); singlereal num_y = (a.y * b.x) - (a.x * b.y); singlereal denom = (b.x * b.x) + (b.y * b.y); c.x = num_x / denom; c.y = num_y / denom
#else
- #define DivideFull(c, a, b) c = a / b
+ #define DivideFull(c,a,b) c = a / b
#endif
// The scalar AXPBY function
#if PRECISION == 3232 || PRECISION == 6464
- #define AXPBY(e, a, b, c, d) e.x = MulReal(a,b) + MulReal(c,d); e.y = MulImag(a,b) + MulImag(c,d)
+ #define AXPBY(e,a,b,c,d) e.x = MulReal(a,b) + MulReal(c,d); e.y = MulImag(a,b) + MulImag(c,d)
#else
- #define AXPBY(e, a, b, c, d) e = a*b + c*d
+ #define AXPBY(e,a,b,c,d) e = a*b + c*d
#endif
// The complex conjugate operation for complex transforms