diff options
-rw-r--r-- | .travis.yml | 16 | ||||
-rw-r--r-- | samples/cache.c | 2 | ||||
-rw-r--r-- | samples/dgemv.c | 2 | ||||
-rw-r--r-- | samples/haxpy.c | 2 | ||||
-rw-r--r-- | samples/sasum.c | 2 | ||||
-rw-r--r-- | samples/sgemm.c | 2 | ||||
-rw-r--r-- | samples/sgemm.cpp | 5 | ||||
-rwxr-xr-x | scripts/generator/generator.py | 2 | ||||
-rw-r--r-- | src/clblast.cpp | 121 | ||||
-rw-r--r-- | src/clblast_c.cpp | 2 | ||||
-rw-r--r-- | src/clpp11.hpp | 51 | ||||
-rw-r--r-- | src/routine.cpp | 15 | ||||
-rw-r--r-- | src/routines/common.hpp | 2 | ||||
-rw-r--r-- | src/utilities/clblast_exceptions.hpp | 2 | ||||
-rw-r--r-- | src/utilities/utilities.hpp | 2 | ||||
-rw-r--r-- | test/correctness/tester.cpp | 34 | ||||
-rw-r--r-- | test/correctness/tester.hpp | 4 | ||||
-rw-r--r-- | test/performance/client.hpp | 4 |
18 files changed, 160 insertions, 110 deletions
diff --git a/.travis.yml b/.travis.yml index 0465afa4..6a47bbd7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,14 +2,6 @@ language: cpp sudo: required dist: trusty -os: - - linux - - osx - -compiler: - - gcc - - clang - addons: apt: sources: @@ -19,6 +11,14 @@ addons: - cmake - ocl-icd-opencl-dev +matrix: + include: + - os: linux + compiler: gcc + - os: linux + compiler: clang + - os: osx + env: global: - CLBLAST_ROOT=${TRAVIS_BUILD_DIR}/bin/clblast diff --git a/samples/cache.c b/samples/cache.c index 40f2163f..980c7cf3 100644 --- a/samples/cache.c +++ b/samples/cache.c @@ -20,6 +20,8 @@ #include <string.h> #include <time.h> +#define CL_USE_DEPRECATED_OPENCL_1_2_APIS // to disable deprecation warnings + // Includes the CLBlast library (C interface) #include <clblast_c.h> diff --git a/samples/dgemv.c b/samples/dgemv.c index dc2fe7db..975cb7ac 100644 --- a/samples/dgemv.c +++ b/samples/dgemv.c @@ -19,6 +19,8 @@ #include <stdio.h> #include <string.h> +#define CL_USE_DEPRECATED_OPENCL_1_2_APIS // to disable deprecation warnings + // Includes the CLBlast library (C interface) #include <clblast_c.h> diff --git a/samples/haxpy.c b/samples/haxpy.c index 8e0833f8..4f2bb400 100644 --- a/samples/haxpy.c +++ b/samples/haxpy.c @@ -18,6 +18,8 @@ #include <stdio.h> #include <string.h> +#define CL_USE_DEPRECATED_OPENCL_1_2_APIS // to disable deprecation warnings + // Includes the CLBlast library (C interface) #include <clblast_c.h> diff --git a/samples/sasum.c b/samples/sasum.c index c285dd14..78377336 100644 --- a/samples/sasum.c +++ b/samples/sasum.c @@ -19,6 +19,8 @@ #include <stdio.h> #include <string.h> +#define CL_USE_DEPRECATED_OPENCL_1_2_APIS // to disable deprecation warnings + // Includes the CLBlast library (C interface) #include <clblast_c.h> diff --git a/samples/sgemm.c b/samples/sgemm.c index 132dad81..92f3057d 100644 --- a/samples/sgemm.c +++ b/samples/sgemm.c @@ -19,6 +19,8 @@ #include <stdio.h> #include <string.h> +#define CL_USE_DEPRECATED_OPENCL_1_2_APIS // to disable deprecation warnings + // Includes the CLBlast library (C interface) #include <clblast_c.h> diff --git a/samples/sgemm.cpp b/samples/sgemm.cpp index 401ecff8..b960865b 100644 --- a/samples/sgemm.cpp +++ b/samples/sgemm.cpp @@ -20,6 +20,9 @@ #include <chrono> #include <vector> +#define CL_USE_DEPRECATED_OPENCL_1_1_APIS // to disable deprecation warnings +#define CL_USE_DEPRECATED_OPENCL_1_2_APIS // to disable deprecation warnings + // Includes the C++ OpenCL API. If not yet available, it can be found here: // https://www.khronos.org/registry/cl/api/1.1/cl.hpp #include "cl.hpp" @@ -103,7 +106,7 @@ int main() { auto time_ms = std::chrono::duration<double,std::milli>(elapsed_time).count(); // Example completed. See "clblast.h" for status codes (0 -> success). - printf("Completed SGEMM in %.3lf ms with status %d\n", time_ms, status); + printf("Completed SGEMM in %.3lf ms with status %d\n", time_ms, static_cast<int>(status)); return 0; } diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index 35d902b7..6591cbf7 100755 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -42,7 +42,7 @@ FILES = [ "/src/clblast_netlib_c.cpp", ] HEADER_LINES = [117, 73, 118, 22, 29, 41, 65, 32] -FOOTER_LINES = [17, 80, 19, 18, 6, 6, 9, 2] +FOOTER_LINES = [17, 95, 19, 18, 6, 6, 9, 2] # Different possibilities for requirements ald_m = "The value of `a_ld` must be at least `m`." diff --git a/src/clblast.cpp b/src/clblast.cpp index 4bb4e0b3..e0f8add2 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -15,8 +15,8 @@ #include <string> -#include "clblast.h" #include "cache.hpp" +#include "clblast.h" // BLAS level-1 includes #include "routines/level1/xswap.hpp" @@ -2170,6 +2170,71 @@ StatusCode ClearCache() { return StatusCode::kSuccess; } +template <typename Real, typename Complex> +void FillCacheForPrecision(Queue &queue) { + try { + + // Runs all the level 1 set-up functions + Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr); + Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr); + Xscal<Real>(queue, nullptr); Xscal<Complex>(queue, nullptr); + Xcopy<Real>(queue, nullptr); Xcopy<Complex>(queue, nullptr); + Xaxpy<Real>(queue, nullptr); Xaxpy<Complex>(queue, nullptr); + Xdot<Real>(queue, nullptr); + Xdotu<Complex>(queue, nullptr); + Xdotc<Complex>(queue, nullptr); + Xnrm2<Real>(queue, nullptr); Xnrm2<Complex>(queue, nullptr); + Xasum<Real>(queue, nullptr); Xasum<Complex>(queue, nullptr); + Xsum<Real>(queue, nullptr); Xsum<Complex>(queue, nullptr); + Xamax<Real>(queue, nullptr); Xamax<Complex>(queue, nullptr); + Xmax<Real>(queue, nullptr); Xmax<Complex>(queue, nullptr); + Xmin<Real>(queue, nullptr); Xmin<Complex>(queue, nullptr); + + // Runs all the level 2 set-up functions + Xgemv<Real>(queue, nullptr); Xgemv<Complex>(queue, nullptr); + Xgbmv<Real>(queue, nullptr); Xgbmv<Complex>(queue, nullptr); + Xhemv<Complex>(queue, nullptr); + Xhbmv<Complex>(queue, nullptr); + Xhpmv<Complex>(queue, nullptr); + Xsymv<Real>(queue, nullptr); + Xsbmv<Real>(queue, nullptr); + Xspmv<Real>(queue, nullptr); + Xtrmv<Real>(queue, nullptr); Xtrmv<Complex>(queue, nullptr); + Xtbmv<Real>(queue, nullptr); Xtbmv<Complex>(queue, nullptr); + Xtpmv<Real>(queue, nullptr); Xtpmv<Complex>(queue, nullptr); + Xger<Real>(queue, nullptr); + Xgeru<Complex>(queue, nullptr); + Xgerc<Complex>(queue, nullptr); + Xher<Complex,Real>(queue, nullptr); + Xhpr<Complex,Real>(queue, nullptr); + Xher2<Complex>(queue, nullptr); + Xhpr2<Complex>(queue, nullptr); + Xsyr<Real>(queue, nullptr); + Xspr<Real>(queue, nullptr); + Xsyr2<Real>(queue, nullptr); + Xspr2<Real>(queue, nullptr); + + // Runs all the level 3 set-up functions + Xgemm<Real>(queue, nullptr); Xgemm<Complex>(queue, nullptr); + Xsymm<Real>(queue, nullptr); Xsymm<Complex>(queue, nullptr); + Xhemm<Complex>(queue, nullptr); + Xsyrk<Real>(queue, nullptr); Xsyrk<Complex>(queue, nullptr); + Xherk<Complex,Real>(queue, nullptr); + Xsyr2k<Real>(queue, nullptr); Xsyr2k<Complex>(queue, nullptr); + Xher2k<Complex,Real>(queue, nullptr); + Xtrmm<Real>(queue, nullptr); Xtrmm<Complex>(queue, nullptr); + + // Runs all the non-BLAS set-up functions + Xomatcopy<Real>(queue, nullptr); Xomatcopy<Complex>(queue, nullptr); + + } catch(const RuntimeErrorCode &e) { + if (e.status() != StatusCode::kNoDoublePrecision && + e.status() != StatusCode::kNoHalfPrecision) { + throw; + } + } +} + // Fills the cache with all binaries for a specific device // TODO: Add half-precision FP16 set-up calls StatusCode FillCache(const cl_device_id device) { @@ -2180,58 +2245,8 @@ StatusCode FillCache(const cl_device_id device) { auto context = Context(device_cpp); auto queue = Queue(context, device_cpp); - // Runs all the level 1 set-up functions - Xswap<float>(queue, nullptr); Xswap<double>(queue, nullptr); Xswap<float2>(queue, nullptr); Xswap<double2>(queue, nullptr); - Xswap<float>(queue, nullptr); Xswap<double>(queue, nullptr); Xswap<float2>(queue, nullptr); Xswap<double2>(queue, nullptr); - Xscal<float>(queue, nullptr); Xscal<double>(queue, nullptr); Xscal<float2>(queue, nullptr); Xscal<double2>(queue, nullptr); - Xcopy<float>(queue, nullptr); Xcopy<double>(queue, nullptr); Xcopy<float2>(queue, nullptr); Xcopy<double2>(queue, nullptr); - Xaxpy<float>(queue, nullptr); Xaxpy<double>(queue, nullptr); Xaxpy<float2>(queue, nullptr); Xaxpy<double2>(queue, nullptr); - Xdot<float>(queue, nullptr); Xdot<double>(queue, nullptr); - Xdotu<float2>(queue, nullptr); Xdotu<double2>(queue, nullptr); - Xdotc<float2>(queue, nullptr); Xdotc<double2>(queue, nullptr); - Xnrm2<float>(queue, nullptr); Xnrm2<double>(queue, nullptr); Xnrm2<float2>(queue, nullptr); Xnrm2<double2>(queue, nullptr); - Xasum<float>(queue, nullptr); Xasum<double>(queue, nullptr); Xasum<float2>(queue, nullptr); Xasum<double2>(queue, nullptr); - Xsum<float>(queue, nullptr); Xsum<double>(queue, nullptr); Xsum<float2>(queue, nullptr); Xsum<double2>(queue, nullptr); - Xamax<float>(queue, nullptr); Xamax<double>(queue, nullptr); Xamax<float2>(queue, nullptr); Xamax<double2>(queue, nullptr); - Xmax<float>(queue, nullptr); Xmax<double>(queue, nullptr); Xmax<float2>(queue, nullptr); Xmax<double2>(queue, nullptr); - Xmin<float>(queue, nullptr); Xmin<double>(queue, nullptr); Xmin<float2>(queue, nullptr); Xmin<double2>(queue, nullptr); - - // Runs all the level 2 set-up functions - Xgemv<float>(queue, nullptr); Xgemv<double>(queue, nullptr); Xgemv<float2>(queue, nullptr); Xgemv<double2>(queue, nullptr); - Xgbmv<float>(queue, nullptr); Xgbmv<double>(queue, nullptr); Xgbmv<float2>(queue, nullptr); Xgbmv<double2>(queue, nullptr); - Xhemv<float2>(queue, nullptr); Xhemv<double2>(queue, nullptr); - Xhbmv<float2>(queue, nullptr); Xhbmv<double2>(queue, nullptr); - Xhpmv<float2>(queue, nullptr); Xhpmv<double2>(queue, nullptr); - Xsymv<float>(queue, nullptr); Xsymv<double>(queue, nullptr); - Xsbmv<float>(queue, nullptr); Xsbmv<double>(queue, nullptr); - Xspmv<float>(queue, nullptr); Xspmv<double>(queue, nullptr); - Xtrmv<float>(queue, nullptr); Xtrmv<double>(queue, nullptr); Xtrmv<float2>(queue, nullptr); Xtrmv<double2>(queue, nullptr); - Xtbmv<float>(queue, nullptr); Xtbmv<double>(queue, nullptr); Xtbmv<float2>(queue, nullptr); Xtbmv<double2>(queue, nullptr); - Xtpmv<float>(queue, nullptr); Xtpmv<double>(queue, nullptr); Xtpmv<float2>(queue, nullptr); Xtpmv<double2>(queue, nullptr); - Xger<float>(queue, nullptr); Xger<double>(queue, nullptr); - Xgeru<float2>(queue, nullptr); Xgeru<double2>(queue, nullptr); - Xgerc<float2>(queue, nullptr); Xgerc<double2>(queue, nullptr); - Xher<float2,float>(queue, nullptr); Xher<double2,double>(queue, nullptr); - Xhpr<float2,float>(queue, nullptr); Xhpr<double2,double>(queue, nullptr); - Xher2<float2>(queue, nullptr); Xher2<double2>(queue, nullptr); - Xhpr2<float2>(queue, nullptr); Xhpr2<double2>(queue, nullptr); - Xsyr<float>(queue, nullptr); Xsyr<double>(queue, nullptr); - Xspr<float>(queue, nullptr); Xspr<double>(queue, nullptr); - Xsyr2<float>(queue, nullptr); Xsyr2<double>(queue, nullptr); - Xspr2<float>(queue, nullptr); Xspr2<double>(queue, nullptr); - - // Runs all the level 3 set-up functions - Xgemm<float>(queue, nullptr); Xgemm<double>(queue, nullptr); Xgemm<float2>(queue, nullptr); Xgemm<double2>(queue, nullptr); - Xsymm<float>(queue, nullptr); Xsymm<double>(queue, nullptr); Xsymm<float2>(queue, nullptr); Xsymm<double2>(queue, nullptr); - Xhemm<float2>(queue, nullptr); Xhemm<double2>(queue, nullptr); - Xsyrk<float>(queue, nullptr); Xsyrk<double>(queue, nullptr); Xsyrk<float2>(queue, nullptr); Xsyrk<double2>(queue, nullptr); - Xherk<float2,float>(queue, nullptr); Xherk<double2,double>(queue, nullptr); - Xsyr2k<float>(queue, nullptr); Xsyr2k<double>(queue, nullptr); Xsyr2k<float2>(queue, nullptr); Xsyr2k<double2>(queue, nullptr); - Xher2k<float2,float>(queue, nullptr); Xher2k<double2,double>(queue, nullptr); - Xtrmm<float>(queue, nullptr); Xtrmm<double>(queue, nullptr); Xtrmm<float2>(queue, nullptr); Xtrmm<double2>(queue, nullptr); - - // Runs all the level 3 set-up functions - Xomatcopy<float>(queue, nullptr); Xomatcopy<double>(queue, nullptr); Xomatcopy<float2>(queue, nullptr); Xomatcopy<double2>(queue, nullptr); + FillCacheForPrecision<float, float2>(queue); + FillCacheForPrecision<double, double2>(queue); } catch (...) { return DispatchException(); } return StatusCode::kSuccess; diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp index 59e4cd16..e4f2b3ed 100644 --- a/src/clblast_c.cpp +++ b/src/clblast_c.cpp @@ -13,9 +13,9 @@ #include <string> +#include "utilities/utilities.hpp" #include "clblast_c.h" #include "clblast.h" -#include "utilities/utilities.hpp" // Shortcuts to the clblast namespace using float2 = clblast::float2; diff --git a/src/clpp11.hpp b/src/clpp11.hpp index 0383f53a..c984661c 100644 --- a/src/clpp11.hpp +++ b/src/clpp11.hpp @@ -333,7 +333,10 @@ class Context { // Regular constructor with memory management explicit Context(const Device &device): - context_(new cl_context, [](cl_context* c) { CheckErrorDtor(clReleaseContext(*c)); delete c; }) { + context_(new cl_context, [](cl_context* c) { + if (*c) { CheckErrorDtor(clReleaseContext(*c)); } + delete c; + }) { auto status = CL_SUCCESS; const cl_device_id dev = device(); *context_ = clCreateContext(nullptr, 1, &dev, nullptr, nullptr, &status); @@ -355,33 +358,37 @@ using ContextPointer = cl_context*; // Enumeration of build statuses of the run-time compilation process enum class BuildStatus { kSuccess, kError, kInvalid }; -// C++11 version of 'cl_program'. Additionally holds the program's source code. +// C++11 version of 'cl_program'. class Program { public: // Note that there is no constructor based on the regular OpenCL data-type because of extra state // Source-based constructor with memory management - explicit Program(const Context &context, std::string source): - program_(new cl_program, [](cl_program* p) { CheckErrorDtor(clReleaseProgram(*p)); delete p; }), - length_(source.length()), - source_(std::move(source)), - source_ptr_(&source_[0]) { + explicit Program(const Context &context, const std::string &source): + program_(new cl_program, [](cl_program* p) { + if (*p) { CheckErrorDtor(clReleaseProgram(*p)); } + delete p; + }) { + const char *source_ptr = &source[0]; + size_t length = source.length(); auto status = CL_SUCCESS; - *program_ = clCreateProgramWithSource(context(), 1, &source_ptr_, &length_, &status); + *program_ = clCreateProgramWithSource(context(), 1, &source_ptr, &length, &status); CLError::Check(status, "clCreateProgramWithSource"); } // Binary-based constructor with memory management - explicit Program(const Device &device, const Context &context, const std::string& binary): - program_(new cl_program, [](cl_program* p) { CheckErrorDtor(clReleaseProgram(*p)); delete p; }), - length_(binary.length()), - source_(binary), - source_ptr_(&source_[0]) { + explicit Program(const Device &device, const Context &context, const std::string &binary): + program_(new cl_program, [](cl_program* p) { + if (*p) { CheckErrorDtor(clReleaseProgram(*p)); } + delete p; + }) { + const char *binary_ptr = &binary[0]; + size_t length = binary.length(); auto status1 = CL_SUCCESS; auto status2 = CL_SUCCESS; const cl_device_id dev = device(); - *program_ = clCreateProgramWithBinary(context(), 1, &dev, &length_, - reinterpret_cast<const unsigned char**>(&source_ptr_), + *program_ = clCreateProgramWithBinary(context(), 1, &dev, &length, + reinterpret_cast<const unsigned char**>(&binary_ptr), &status1, &status2); CLError::Check(status1, "clCreateProgramWithBinary (binary status)"); CLError::Check(status2, "clCreateProgramWithBinary"); @@ -421,9 +428,6 @@ class Program { const cl_program& operator()() const { return *program_; } private: std::shared_ptr<cl_program> program_; - size_t length_; - std::string source_; // Note: the source can also be a binary or IR - const char* source_ptr_; }; // ================================================================================================= @@ -440,8 +444,10 @@ class Queue { // Regular constructor with memory management explicit Queue(const Context &context, const Device &device): - queue_(new cl_command_queue, [](cl_command_queue* s) { CheckErrorDtor(clReleaseCommandQueue(*s)); - delete s; }) { + queue_(new cl_command_queue, [](cl_command_queue* s) { + if (*s) { CheckErrorDtor(clReleaseCommandQueue(*s)); } + delete s; + }) { auto status = CL_SUCCESS; *queue_ = clCreateCommandQueue(context(), device(), CL_QUEUE_PROFILING_ENABLE, &status); CLError::Check(status, "clCreateCommandQueue"); @@ -665,7 +671,10 @@ class Kernel { // Regular constructor with memory management explicit Kernel(const Program &program, const std::string &name): - kernel_(new cl_kernel, [](cl_kernel* k) { CheckErrorDtor(clReleaseKernel(*k)); delete k; }) { + kernel_(new cl_kernel, [](cl_kernel* k) { + if (*k) { CheckErrorDtor(clReleaseKernel(*k)); } + delete k; + }) { auto status = CL_SUCCESS; *kernel_ = clCreateKernel(program(), name.c_str(), &status); CLError::Check(status, "clCreateKernel"); diff --git a/src/routine.cpp b/src/routine.cpp index acafb0d2..d5a6b589 100644 --- a/src/routine.cpp +++ b/src/routine.cpp @@ -52,24 +52,21 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name, auto program = Program(device_, context_, binary); program.Build(device_, options); StoreProgramToCache(program, context_, precision_, routine_name_); + return; } // Otherwise, the kernel will be compiled and program will be built. Both the binary and the // program will be added to the cache. // Inspects whether or not cl_khr_fp64 is supported in case of double precision - const auto extensions = device_.Capabilities(); - if (precision_ == Precision::kDouble || precision_ == Precision::kComplexDouble) { - if (extensions.find(kKhronosDoublePrecision) == std::string::npos) { - throw RuntimeErrorCode(StatusCode::kNoDoublePrecision); - } + if ((precision_ == Precision::kDouble && !PrecisionSupported<double>(device_)) || + (precision_ == Precision::kComplexDouble && !PrecisionSupported<double2>(device_))) { + throw RuntimeErrorCode(StatusCode::kNoDoublePrecision); } // As above, but for cl_khr_fp16 (half precision) - if (precision_ == Precision::kHalf) { - if (extensions.find(kKhronosHalfPrecision) == std::string::npos) { - throw RuntimeErrorCode(StatusCode::kNoHalfPrecision); - } + if (precision_ == Precision::kHalf && !PrecisionSupported<half>(device_)) { + throw RuntimeErrorCode(StatusCode::kNoHalfPrecision); } // Collects the parameters for this device in the form of defines, and adds the precision diff --git a/src/routines/common.hpp b/src/routines/common.hpp index 53ca6355..7c211c0d 100644 --- a/src/routines/common.hpp +++ b/src/routines/common.hpp @@ -19,8 +19,8 @@ #include <string> #include <vector> -#include "clblast.h" #include "clpp11.hpp" +#include "clblast.h" #include "database/database.hpp" namespace clblast { diff --git a/src/utilities/clblast_exceptions.hpp b/src/utilities/clblast_exceptions.hpp index f3c7b9a3..0d0033b6 100644 --- a/src/utilities/clblast_exceptions.hpp +++ b/src/utilities/clblast_exceptions.hpp @@ -16,8 +16,8 @@ #ifndef CLBLAST_EXCEPTIONS_H_ #define CLBLAST_EXCEPTIONS_H_ -#include "clblast.h" #include "clpp11.hpp" +#include "clblast.h" namespace clblast { // ================================================================================================= diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp index 20587bd4..a1d4e2db 100644 --- a/src/utilities/utilities.hpp +++ b/src/utilities/utilities.hpp @@ -21,9 +21,9 @@ #include <functional> #include <complex> +#include "clpp11.hpp" #include "clblast.h" #include "clblast_half.h" -#include "clpp11.hpp" #include "utilities/clblast_exceptions.hpp" #include "utilities/msvc.hpp" diff --git a/test/correctness/tester.cpp b/test/correctness/tester.cpp index c449b09d..efe49811 100644 --- a/test/correctness/tester.cpp +++ b/test/correctness/tester.cpp @@ -248,8 +248,29 @@ template <typename T, typename U> void Tester<T,U>::TestErrorCodes(const StatusCode clblas_status, const StatusCode clblast_status, const Arguments<U> &args) { + // Either an OpenCL or CLBlast internal error occurred, fail the test immediately + // NOTE: the OpenCL error codes grow downwards without any declared lower bound, hence the magic + // number. The last error code is atm around -70, but -500 is chosen to be on the safe side. + if (clblast_status != StatusCode::kSuccess && + (clblast_status > static_cast<StatusCode>(-500) /* matches OpenCL errors (see above) */ || + clblast_status < StatusCode::kNotImplemented) /* matches CLBlast internal errors */) { + PrintTestResult(kErrorStatus); + ReportError({StatusCode::kSuccess, clblast_status, kStatusError, args}); + if (verbose_) { + fprintf(stdout, "\n"); + PrintErrorLog({{StatusCode::kSuccess, clblast_status, kStatusError, args}}); + fprintf(stdout, " "); + } + } + + // Routine is not implemented + else if (clblast_status == StatusCode::kNotImplemented) { + PrintTestResult(kSkippedCompilation); + ReportSkipped(); + } + // Cannot compare error codes against a library other than clBLAS - if (compare_cblas_) { + else if (compare_cblas_) { PrintTestResult(kUnsupportedReference); ReportSkipped(); } @@ -267,13 +288,6 @@ void Tester<T,U>::TestErrorCodes(const StatusCode clblas_status, const StatusCod ReportSkipped(); } - // Could not compile the CLBlast kernel properly - else if (clblast_status == StatusCode::kOpenCLBuildProgramFailure || - clblast_status == StatusCode::kNotImplemented) { - PrintTestResult(kSkippedCompilation); - ReportSkipped(); - } - // Error occurred else { PrintTestResult(kErrorStatus); @@ -388,7 +402,9 @@ void Tester<T,U>::PrintErrorLog(const std::vector<ErrorLogEntry> &error_log) { fprintf(stdout, " Error rate %.1lf%%: ", entry.error_percentage); } else { - fprintf(stdout, " Status code %d (expected %d): ", entry.status_found, entry.status_expect); + fprintf(stdout, " Status code %d (expected %d): ", + static_cast<int>(entry.status_found), + static_cast<int>(entry.status_expect)); } fprintf(stdout, "%s\n", GetOptionsString(entry.args).c_str()); } diff --git a/test/correctness/tester.hpp b/test/correctness/tester.hpp index d8462cef..113f03ef 100644 --- a/test/correctness/tester.hpp +++ b/test/correctness/tester.hpp @@ -22,14 +22,14 @@ #include <vector> #include <memory> +#include "utilities/utilities.hpp" + // The libraries #ifdef CLBLAST_REF_CLBLAS #include <clBLAS.h> #endif #include "clblast.h" -#include "utilities/utilities.hpp" - namespace clblast { // ================================================================================================= diff --git a/test/performance/client.hpp b/test/performance/client.hpp index 4554c67f..4b3e17c7 100644 --- a/test/performance/client.hpp +++ b/test/performance/client.hpp @@ -25,14 +25,14 @@ #include <vector> #include <utility> +#include "utilities/utilities.hpp" + // The libraries to test #ifdef CLBLAST_REF_CLBLAS #include <clBLAS.h> #endif #include "clblast.h" -#include "utilities/utilities.hpp" - namespace clblast { // ================================================================================================= |