summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-12-05 20:39:49 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2017-12-05 20:39:49 +0100
commit0f9637bbac6248a381d7012d7224331d3d394efb (patch)
tree958093804cd0f1be907a2b748c0477fd811cbb35
parentcf4555d1f44aea9c82b60211b5650b6b77a1226c (diff)
Improved array-to-register promotion, now handling function calls as well
-rw-r--r--src/kernel_preprocessor.cpp155
-rw-r--r--src/kernels/level2/xgemv.opencl14
-rw-r--r--src/kernels/level2/xgemv_fast.opencl76
-rw-r--r--test/correctness/misc/preprocessor.cpp49
4 files changed, 214 insertions, 80 deletions
diff --git a/src/kernel_preprocessor.cpp b/src/kernel_preprocessor.cpp
index f0245ce6..6239361b 100644
--- a/src/kernel_preprocessor.cpp
+++ b/src/kernel_preprocessor.cpp
@@ -18,6 +18,8 @@
// The above also requires the spaces in that exact form
// - The loop variable should be a unique string within the code in the for-loop body (e.g. don't
// use 'i' or 'w' but rather '_w' or a longer name.
+// - The pragma "#pragma promote_to_registers" unrolls an array into multiple scalar values. The
+// name of this scalar should be unique (see above).
//
// =================================================================================================
@@ -176,26 +178,85 @@ bool EvaluateCondition(std::string condition,
// Array to register promotion, e.g. arr[w] to {arr_0, arr_1}
void ArrayToRegister(std::string &source_line, const std::unordered_map<std::string, int>& defines,
- const std::unordered_map<std::string, size_t>& arrays_to_registers) {
+ const std::unordered_map<std::string, size_t>& arrays_to_registers,
+ const size_t num_brackets) {
for (const auto array_name_map : arrays_to_registers) { // only if marked to be promoted
- auto array_pos = source_line.find(array_name_map.first + "[");
- while (array_pos != std::string::npos) {
-
- // Retrieves the array index
- const auto loop_remainder = source_line.substr(array_pos);
- const auto loop_split = split(split(loop_remainder, '[')[1], ']');
- if (loop_split.size() < 2) { RaiseError(source_line, "Mis-formatted array declaration"); }
- auto array_index_string = loop_split[0];
-
- // Replaces the array with a register value
- SubstituteDefines(defines, array_index_string);
- const auto array_index = StringToDigit(array_index_string, source_line);
- FindReplace(source_line, array_name_map.first + "[" + loop_split[0] + "]",
- array_name_map.first + "_" + ToString(array_index));
-
- // Performs an extra substitution if this array occurs another time in this line
- array_pos = source_line.find(array_name_map.first + "[");
+
+ // Outside of a function
+ if (num_brackets == 0) {
+
+ // Case 1: argument in a function declaration (e.g. 'void func(const float arr[2])')
+ const auto array_pos = source_line.find(array_name_map.first + "[");
+ if (array_pos != std::string::npos) {
+ SubstituteDefines(defines, source_line);
+
+ // Finds the full array declaration (e.g. 'const float arr[2]')
+ const auto left_split = split(source_line, '(');
+ auto arguments = left_split.size() >= 2 ? left_split[1] : source_line;
+ const auto right_split = split(arguments, ')');
+ arguments = right_split.size() >= 1 ? right_split[0] : arguments;
+ const auto comma_split = split(arguments, ',');
+ for (auto j = size_t{0}; j < comma_split.size(); ++j) {
+ if (comma_split[j].find(array_name_map.first + "[") != std::string::npos) {
+
+ // Retrieves the array index
+ const auto left_square_split = split(comma_split[j], '[');
+ if (left_square_split.size() < 2) { RaiseError(source_line, "Mis-formatted array declaration #A"); }
+ const auto right_square_split = split(left_square_split[1], ']');
+ if (right_square_split.size() < 1) { RaiseError(source_line, "Mis-formatted array declaration #B"); }
+ auto array_index_string = right_square_split[0];
+ const auto array_index = StringToDigit(array_index_string, source_line);
+
+ // Creates the new string
+ auto replacement = std::string{};
+ for (auto index = size_t{0}; index < array_index; ++index) {
+ replacement += left_square_split[0] + "_" + ToString(index);
+ if (index != array_index - 1) { replacement += ","; }
+ }
+
+ // Performs the actual replacement
+ FindReplace(source_line, comma_split[j], replacement);
+ }
+ }
+ }
+ }
+
+ // Inside a function
+ else {
+ auto array_pos = source_line.find(array_name_map.first + "[");
+
+ // Case 2: passed to another function (e.g. 'func(arr)')
+ if (array_pos == std::string::npos) { // assumes case 2 and case 3 (below) cannot occur in one line
+ auto bracket_split = split(source_line, '(');
+ if (bracket_split.size() >= 2) {
+ auto replacement = std::string{};
+ for (auto i = size_t{0}; i < array_name_map.second; ++i) {
+ replacement += array_name_map.first + "_" + ToString(i);
+ if (i != array_name_map.second - 1) { replacement += ", "; }
+ }
+ FindReplace(source_line, array_name_map.first, replacement);
+ }
+ }
+
+ // Case 2: used as an array (e.g. 'arr[w]')
+ while (array_pos != std::string::npos) {
+
+ // Retrieves the array index
+ const auto loop_remainder = source_line.substr(array_pos);
+ const auto loop_split = split(split(loop_remainder, '[')[1], ']');
+ if (loop_split.size() < 2) { RaiseError(source_line, "Mis-formatted array declaration #C"); }
+ auto array_index_string = loop_split[0];
+
+ // Replaces the array with a register value
+ SubstituteDefines(defines, array_index_string);
+ const auto array_index = StringToDigit(array_index_string, source_line);
+ FindReplace(source_line, array_name_map.first + "[" + loop_split[0] + "]",
+ array_name_map.first + "_" + ToString(array_index));
+
+ // Performs an extra substitution if this array occurs another time in this line
+ array_pos = source_line.find(array_name_map.first + "[");
+ }
}
}
}
@@ -319,31 +380,21 @@ std::vector<std::string> PreprocessDefinesAndComments(const std::string& source,
// =================================================================================================
-// Second pass: unroll loops
+// Second pass: detect array-to-register promotion pragma's and replace declarations & function calls
std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& source_lines,
const std::unordered_map<std::string, int>& defines,
- std::unordered_map<std::string, size_t>& arrays_to_registers,
- const bool array_to_register_promotion) {
+ std::unordered_map<std::string, size_t>& arrays_to_registers) {
auto lines = std::vector<std::string>();
auto brackets = size_t{0};
- auto unroll_next_loop = false;
auto promote_next_array_to_registers = false;
for (auto line_id = size_t{0}; line_id < source_lines.size(); ++line_id) {
auto line = source_lines[line_id];
// Detect #pragma promote_to_registers directives (unofficial pragma)
- if (array_to_register_promotion) {
- if (line.find("#pragma promote_to_registers") != std::string::npos) {
- promote_next_array_to_registers = true;
- continue;
- }
- }
-
- // Detect #pragma unroll directives
- if (line.find("#pragma unroll") != std::string::npos) {
- unroll_next_loop = true;
+ if (line.find("#pragma promote_to_registers") != std::string::npos) {
+ promote_next_array_to_registers = true;
continue;
}
@@ -369,10 +420,43 @@ std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& s
const auto array_name_split = split(line_split1[0], ' ');
if (array_name_split.size() < 2) { RaiseError(line, "Mis-formatted array declaration #2"); }
const auto array_name = array_name_split[array_name_split.size() - 1];
- arrays_to_registers[array_name] = brackets; // TODO: bracket count not used currently for scope checking
+ arrays_to_registers[array_name] = array_size;
+ // TODO: bracket count not used currently for scope checking
continue;
}
+ // Regular line
+ lines.emplace_back(line);
+ }
+ return lines;
+}
+
+// =================================================================================================
+
+// Third pass: unroll loops and perform actual array-to-register promotion
+std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& source_lines,
+ const std::unordered_map<std::string, int>& defines,
+ std::unordered_map<std::string, size_t>& arrays_to_registers,
+ const bool array_to_register_promotion) {
+ auto lines = std::vector<std::string>();
+
+ auto brackets = size_t{0};
+ auto unroll_next_loop = false;
+
+ for (auto line_id = size_t{0}; line_id < source_lines.size(); ++line_id) {
+ auto line = source_lines[line_id];
+
+ // Detect #pragma unroll directives
+ if (line.find("#pragma unroll") != std::string::npos) {
+ unroll_next_loop = true;
+ continue;
+ }
+
+ // Brackets
+ const auto num_brackets_before = brackets;
+ brackets += std::count(line.begin(), line.end(), '{');
+ brackets -= std::count(line.begin(), line.end(), '}');
+
// Loop unrolling assuming it to be in the form "for (int w = 0; w < 4; w += 1) {"
if (unroll_next_loop) {
unroll_next_loop = false;
@@ -427,7 +511,7 @@ std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& s
// Array to register promotion
if (array_to_register_promotion) {
- ArrayToRegister(loop_line, defines, arrays_to_registers);
+ ArrayToRegister(loop_line, defines, arrays_to_registers, num_brackets_before);
}
lines.emplace_back(loop_line);
@@ -440,7 +524,7 @@ std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& s
// Array to register promotion
if (array_to_register_promotion) {
- ArrayToRegister(line, defines, arrays_to_registers);
+ ArrayToRegister(line, defines, arrays_to_registers, num_brackets_before);
}
lines.emplace_back(line);
@@ -459,6 +543,7 @@ std::string PreprocessKernelSource(const std::string& kernel_source) {
// Unrolls loops (single level each call)
auto arrays_to_registers = std::unordered_map<std::string, size_t>();
+ lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers);
lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, false);
lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, true);
diff --git a/src/kernels/level2/xgemv.opencl b/src/kernels/level2/xgemv.opencl
index 2a50e8fb..ba29aba6 100644
--- a/src/kernels/level2/xgemv.opencl
+++ b/src/kernels/level2/xgemv.opencl
@@ -228,10 +228,10 @@ void Xgemv(const int m, const int n,
// Initializes the accumulation register
#pragma promote_to_registers
- real acc[WPT1];
+ real acc1[WPT1];
#pragma unroll
for (int _w = 0; _w < WPT1; _w += 1) {
- SetToZero(acc[_w]);
+ SetToZero(acc1[_w]);
}
// Divides the work in a main and tail section
@@ -262,7 +262,7 @@ void Xgemv(const int m, const int n,
const int k = kwg + kloop + _kunroll;
real value = LoadMatrixA(agm, gid, k, a_ld, a_offset, parameter, kl, ku);
if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
- MultiplyAdd(acc[_w], xlm[kloop + _kunroll], value);
+ MultiplyAdd(acc1[_w], xlm[kloop + _kunroll], value);
}
}
}
@@ -273,7 +273,7 @@ void Xgemv(const int m, const int n,
const int k = kwg + kloop + _kunroll;
real value = LoadMatrixA(agm, k, gid, a_ld, a_offset, parameter, kl, ku);
if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
- MultiplyAdd(acc[_w], xlm[kloop + _kunroll], value);
+ MultiplyAdd(acc1[_w], xlm[kloop + _kunroll], value);
}
}
}
@@ -295,20 +295,20 @@ void Xgemv(const int m, const int n,
for (int k=n_floor; k<n; ++k) {
real value = LoadMatrixA(agm, gid, k, a_ld, a_offset, parameter, kl, ku);
if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
- MultiplyAdd(acc[_w], xgm[k*x_inc + x_offset], value);
+ MultiplyAdd(acc1[_w], xgm[k*x_inc + x_offset], value);
}
}
else { // Transposed
for (int k=n_floor; k<n; ++k) {
real value = LoadMatrixA(agm, k, gid, a_ld, a_offset, parameter, kl, ku);
if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
- MultiplyAdd(acc[_w], xgm[k*x_inc + x_offset], value);
+ MultiplyAdd(acc1[_w], xgm[k*x_inc + x_offset], value);
}
}
// Stores the final result
real yval = ygm[gid*y_inc + y_offset];
- AXPBY(ygm[gid*y_inc + y_offset], alpha, acc[_w], beta, yval);
+ AXPBY(ygm[gid*y_inc + y_offset], alpha, acc1[_w], beta, yval);
}
}
}
diff --git a/src/kernels/level2/xgemv_fast.opencl b/src/kernels/level2/xgemv_fast.opencl
index 892bc55c..45ceb36c 100644
--- a/src/kernels/level2/xgemv_fast.opencl
+++ b/src/kernels/level2/xgemv_fast.opencl
@@ -106,10 +106,10 @@ void XgemvFast(const int m, const int n,
// Initializes the accumulation registers
#pragma promote_to_registers
- real acc[WPT2];
+ real acc2[WPT2];
#pragma unroll
for (int _w = 0; _w < WPT2; _w += 1) {
- SetToZero(acc[_w]);
+ SetToZero(acc2[_w]);
}
// Loops over work-group sized portions of the work
@@ -131,41 +131,41 @@ void XgemvFast(const int m, const int n,
const int gid = (WPT2/VW2)*get_global_id(0) + _w;
realVF avec = agm[(a_ld/VW2)*k + gid];
#if VW2 == 1
- MultiplyAdd(acc[VW2*_w+0], xlm[_kl], avec);
+ MultiplyAdd(acc2[VW2*_w+0], xlm[_kl], avec);
#elif VW2 == 2
- MultiplyAdd(acc[VW2*_w+0], xlm[_kl], avec.x);
- MultiplyAdd(acc[VW2*_w+1], xlm[_kl], avec.y);
+ MultiplyAdd(acc2[VW2*_w+0], xlm[_kl], avec.x);
+ MultiplyAdd(acc2[VW2*_w+1], xlm[_kl], avec.y);
#elif VW2 == 4
- MultiplyAdd(acc[VW2*_w+0], xlm[_kl], avec.x);
- MultiplyAdd(acc[VW2*_w+1], xlm[_kl], avec.y);
- MultiplyAdd(acc[VW2*_w+2], xlm[_kl], avec.z);
- MultiplyAdd(acc[VW2*_w+3], xlm[_kl], avec.w);
+ MultiplyAdd(acc2[VW2*_w+0], xlm[_kl], avec.x);
+ MultiplyAdd(acc2[VW2*_w+1], xlm[_kl], avec.y);
+ MultiplyAdd(acc2[VW2*_w+2], xlm[_kl], avec.z);
+ MultiplyAdd(acc2[VW2*_w+3], xlm[_kl], avec.w);
#elif VW2 == 8
- MultiplyAdd(acc[VW2*_w+0], xlm[_kl], avec.s0);
- MultiplyAdd(acc[VW2*_w+1], xlm[_kl], avec.s1);
- MultiplyAdd(acc[VW2*_w+2], xlm[_kl], avec.s2);
- MultiplyAdd(acc[VW2*_w+3], xlm[_kl], avec.s3);
- MultiplyAdd(acc[VW2*_w+4], xlm[_kl], avec.s4);
- MultiplyAdd(acc[VW2*_w+5], xlm[_kl], avec.s5);
- MultiplyAdd(acc[VW2*_w+6], xlm[_kl], avec.s6);
- MultiplyAdd(acc[VW2*_w+7], xlm[_kl], avec.s7);
+ MultiplyAdd(acc2[VW2*_w+0], xlm[_kl], avec.s0);
+ MultiplyAdd(acc2[VW2*_w+1], xlm[_kl], avec.s1);
+ MultiplyAdd(acc2[VW2*_w+2], xlm[_kl], avec.s2);
+ MultiplyAdd(acc2[VW2*_w+3], xlm[_kl], avec.s3);
+ MultiplyAdd(acc2[VW2*_w+4], xlm[_kl], avec.s4);
+ MultiplyAdd(acc2[VW2*_w+5], xlm[_kl], avec.s5);
+ MultiplyAdd(acc2[VW2*_w+6], xlm[_kl], avec.s6);
+ MultiplyAdd(acc2[VW2*_w+7], xlm[_kl], avec.s7);
#elif VW2 == 16
- MultiplyAdd(acc[VW2*_w+0], xlm[_kl], avec.s0);
- MultiplyAdd(acc[VW2*_w+1], xlm[_kl], avec.s1);
- MultiplyAdd(acc[VW2*_w+2], xlm[_kl], avec.s2);
- MultiplyAdd(acc[VW2*_w+3], xlm[_kl], avec.s3);
- MultiplyAdd(acc[VW2*_w+4], xlm[_kl], avec.s4);
- MultiplyAdd(acc[VW2*_w+5], xlm[_kl], avec.s5);
- MultiplyAdd(acc[VW2*_w+6], xlm[_kl], avec.s6);
- MultiplyAdd(acc[VW2*_w+7], xlm[_kl], avec.s7);
- MultiplyAdd(acc[VW2*_w+8], xlm[_kl], avec.s8);
- MultiplyAdd(acc[VW2*_w+9], xlm[_kl], avec.s9);
- MultiplyAdd(acc[VW2*_w+10], xlm[_kl], avec.sA);
- MultiplyAdd(acc[VW2*_w+11], xlm[_kl], avec.sB);
- MultiplyAdd(acc[VW2*_w+12], xlm[_kl], avec.sC);
- MultiplyAdd(acc[VW2*_w+13], xlm[_kl], avec.sD);
- MultiplyAdd(acc[VW2*_w+14], xlm[_kl], avec.sE);
- MultiplyAdd(acc[VW2*_w+15], xlm[_kl], avec.sF);
+ MultiplyAdd(acc2[VW2*_w+0], xlm[_kl], avec.s0);
+ MultiplyAdd(acc2[VW2*_w+1], xlm[_kl], avec.s1);
+ MultiplyAdd(acc2[VW2*_w+2], xlm[_kl], avec.s2);
+ MultiplyAdd(acc2[VW2*_w+3], xlm[_kl], avec.s3);
+ MultiplyAdd(acc2[VW2*_w+4], xlm[_kl], avec.s4);
+ MultiplyAdd(acc2[VW2*_w+5], xlm[_kl], avec.s5);
+ MultiplyAdd(acc2[VW2*_w+6], xlm[_kl], avec.s6);
+ MultiplyAdd(acc2[VW2*_w+7], xlm[_kl], avec.s7);
+ MultiplyAdd(acc2[VW2*_w+8], xlm[_kl], avec.s8);
+ MultiplyAdd(acc2[VW2*_w+9], xlm[_kl], avec.s9);
+ MultiplyAdd(acc2[VW2*_w+10], xlm[_kl], avec.sA);
+ MultiplyAdd(acc2[VW2*_w+11], xlm[_kl], avec.sB);
+ MultiplyAdd(acc2[VW2*_w+12], xlm[_kl], avec.sC);
+ MultiplyAdd(acc2[VW2*_w+13], xlm[_kl], avec.sD);
+ MultiplyAdd(acc2[VW2*_w+14], xlm[_kl], avec.sE);
+ MultiplyAdd(acc2[VW2*_w+15], xlm[_kl], avec.sF);
#endif
}
}
@@ -179,7 +179,7 @@ void XgemvFast(const int m, const int n,
for (int _w = 0; _w < WPT2; _w += 1) {
const int gid = WPT2*get_global_id(0) + _w;
real yval = ygm[gid*y_inc + y_offset];
- AXPBY(ygm[gid*y_inc + y_offset], alpha, acc[_w], beta, yval);
+ AXPBY(ygm[gid*y_inc + y_offset], alpha, acc2[_w], beta, yval);
}
}
@@ -214,8 +214,8 @@ void XgemvFastRot(const int m, const int n,
__local real xlm[WPT3];
// Initializes the accumulation register
- real acc;
- SetToZero(acc);
+ real acc3;
+ SetToZero(acc3);
// Loops over tile-sized portions of the work
for (int kwg=0; kwg<n; kwg+=WPT3) {
@@ -280,7 +280,7 @@ void XgemvFastRot(const int m, const int n,
for (int _v = 0; _v < VW3; _v += 1) {
real aval = tile[lid_mod*VW3 + _v][lid_div * (WPT3/VW3) + _kl];
real xval = xlm[_kl*VW3 + _v];
- MultiplyAdd(acc, xval, aval);
+ MultiplyAdd(acc3, xval, aval);
}
}
@@ -291,7 +291,7 @@ void XgemvFastRot(const int m, const int n,
// Stores the final result
const int gid = get_global_id(0);
real yval = ygm[gid * y_inc + y_offset];
- AXPBY(ygm[gid * y_inc + y_offset], alpha, acc, beta, yval);
+ AXPBY(ygm[gid * y_inc + y_offset], alpha, acc3, beta, yval);
}
// =================================================================================================
diff --git a/test/correctness/misc/preprocessor.cpp b/test/correctness/misc/preprocessor.cpp
index 71b59c04..fa0d2ccc 100644
--- a/test/correctness/misc/preprocessor.cpp
+++ b/test/correctness/misc/preprocessor.cpp
@@ -60,6 +60,54 @@ bool TestDefines() {
const auto result1 = PreprocessKernelSource(source1);
return result1 == expected1;
}
+// =================================================================================================
+
+bool TestArrayToRegisterPromotion() {
+ const auto source1 =
+ R"(#define WPT 2
+inline void SetValues(int float, float values[WPT],
+ const float k) {
+ #pragma unroll
+ for (int i = 0; i < WPT; i += 1) {
+ values[i] = k + j;
+ }
+}
+__kernel void ExampleKernel() {
+ #pragma promote_to_registers
+ float values[WPT];
+ #pragma unroll
+ for (int i = 0; i < WPT; i += 1) {
+ values[i] = 0.0f;
+ }
+ SetValues(12.3f, values, -3.9f);
+}
+)";
+ const auto expected1 =
+ R"(#define WPT 2
+inline void SetValues(int float, float values_0, float values_1,
+ const float k) {
+ {
+ values_0 = k + j;
+ }
+ {
+ values_1 = k + j;
+ }
+}
+__kernel void ExampleKernel() {
+ float values_0;
+ float values_1;
+ {
+ values_0 = 0.0f;
+ }
+ {
+ values_1 = 0.0f;
+ }
+ SetValues(12.3f, values_0, values_1, -3.9f);
+}
+)";
+ const auto result1 = PreprocessKernelSource(source1);
+ return result1 == expected1;
+}
// =================================================================================================
@@ -110,6 +158,7 @@ size_t RunPreprocessor(int argc, char *argv[], const bool silent, const Precisio
// Basic tests
if (TestDefines()) { passed++; } else { errors++; }
+ if (TestArrayToRegisterPromotion()) { passed++; } else { errors++; }
// XAXPY
const auto xaxpy_sources =