summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/kernel_preprocessor.cpp12
-rw-r--r--test/correctness/misc/preprocessor.cpp20
2 files changed, 28 insertions, 4 deletions
diff --git a/src/kernel_preprocessor.cpp b/src/kernel_preprocessor.cpp
index 1ae83c2d..b6c0b398 100644
--- a/src/kernel_preprocessor.cpp
+++ b/src/kernel_preprocessor.cpp
@@ -33,14 +33,14 @@ namespace clblast {
// =================================================================================================
void RaiseError(const std::string& source_line, const std::string& exception_message) {
- printf("Error in source line: %s\n", source_line.c_str());
+ printf("[OpenCL pre-processor] Error in source line: %s\n", source_line.c_str());
throw Error<std::runtime_error>(exception_message);
}
// =================================================================================================
bool HasOnlyDigits(const std::string& str) {
- return str.find_first_not_of("0123456789") == std::string::npos;
+ return str.find_first_not_of(" 0123456789") == std::string::npos;
}
// Converts a string to an integer. The source line is printed in case an exception is raised.
@@ -161,8 +161,12 @@ std::vector<std::string> PreprocessDefinesAndComments(const std::string& source,
const auto value_pos = define.find(" ");
const auto value = define.substr(value_pos + 1);
const auto name = define.substr(0, value_pos);
- defines.emplace(name, std::stoi(value));
- //continue;
+ if (HasOnlyDigits(value)) {
+ defines.emplace(name, std::stoi(value));
+ }
+ else {
+ printf("'%s'\n", value.c_str());
+ }
}
// Detect #ifndef blocks
diff --git a/test/correctness/misc/preprocessor.cpp b/test/correctness/misc/preprocessor.cpp
index 5df27111..bcc65700 100644
--- a/test/correctness/misc/preprocessor.cpp
+++ b/test/correctness/misc/preprocessor.cpp
@@ -90,11 +90,31 @@ size_t RunPreprocessor(int argc, char *argv[], const bool silent,
"#define WPT1 2\n"
"#define WPT2 2\n"
"#define WPT3 2\n"
+ "#define UNROLL1 4\n"
#include "../src/kernels/level2/xgemv.opencl"
#include "../src/kernels/level2/xgemv_fast.opencl"
;
if (TestKernel(device, context, "XgemvFast", xgemv_sources, precision)) { passed++; } else { errors++; }
+ // CopyFast
+ const auto copy_fast_sources =
+ "#define COPY_WPT 2\n"
+ #include "../src/kernels/level3/level3.opencl"
+ #include "../src/kernels/level3/copy_fast.opencl"
+ ;
+ if (TestKernel(device, context, "CopyMatrixFast", copy_fast_sources, precision)) { passed++; } else { errors++; }
+
+ // CopyPad
+ const auto copy_pad_sources =
+ "#define PAD_WPTX 2\n"
+ "#define PAD_WPTY 2\n"
+#include "../src/kernels/level3/level3.opencl"
+#include "../src/kernels/level3/copy_pad.opencl"
+ ;
+ if (TestKernel(device, context, "CopyPadMatrix", copy_pad_sources, precision)) { passed++; } else { errors++; }
+
+
+
// Prints and returns the statistics
std::cout << std::endl;
std::cout << " " << passed << " test(s) passed" << std::endl;