summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-01-24 09:10:35 +0100
committerGitHub <noreply@github.com>2017-01-24 09:10:35 +0100
commite943fe77d64f42ed1e57c9919de8ca6787760f2b (patch)
tree9da420f6259d2e7a5aafffd530d6a84dea8402e3
parent2e4f6e16098d36d0572769a1092d1e54cdb6d4ea (diff)
parent46a59eb8821e3c92db7be347fca099405246d9ec (diff)
Merge pull request #131 from intelfx/misc
Assorted minor fixes
-rw-r--r--.travis.yml16
-rw-r--r--samples/cache.c2
-rw-r--r--samples/dgemv.c2
-rw-r--r--samples/haxpy.c2
-rw-r--r--samples/sasum.c2
-rw-r--r--samples/sgemm.c2
-rw-r--r--samples/sgemm.cpp5
-rwxr-xr-xscripts/generator/generator.py2
-rw-r--r--src/clblast.cpp121
-rw-r--r--src/clblast_c.cpp2
-rw-r--r--src/clpp11.hpp51
-rw-r--r--src/routine.cpp15
-rw-r--r--src/routines/common.hpp2
-rw-r--r--src/utilities/clblast_exceptions.hpp2
-rw-r--r--src/utilities/utilities.hpp2
-rw-r--r--test/correctness/tester.cpp34
-rw-r--r--test/correctness/tester.hpp4
-rw-r--r--test/performance/client.hpp4
18 files changed, 160 insertions, 110 deletions
diff --git a/.travis.yml b/.travis.yml
index 0465afa4..6a47bbd7 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -2,14 +2,6 @@ language: cpp
sudo: required
dist: trusty
-os:
- - linux
- - osx
-
-compiler:
- - gcc
- - clang
-
addons:
apt:
sources:
@@ -19,6 +11,14 @@ addons:
- cmake
- ocl-icd-opencl-dev
+matrix:
+ include:
+ - os: linux
+ compiler: gcc
+ - os: linux
+ compiler: clang
+ - os: osx
+
env:
global:
- CLBLAST_ROOT=${TRAVIS_BUILD_DIR}/bin/clblast
diff --git a/samples/cache.c b/samples/cache.c
index 40f2163f..980c7cf3 100644
--- a/samples/cache.c
+++ b/samples/cache.c
@@ -20,6 +20,8 @@
#include <string.h>
#include <time.h>
+#define CL_USE_DEPRECATED_OPENCL_1_2_APIS // to disable deprecation warnings
+
// Includes the CLBlast library (C interface)
#include <clblast_c.h>
diff --git a/samples/dgemv.c b/samples/dgemv.c
index dc2fe7db..975cb7ac 100644
--- a/samples/dgemv.c
+++ b/samples/dgemv.c
@@ -19,6 +19,8 @@
#include <stdio.h>
#include <string.h>
+#define CL_USE_DEPRECATED_OPENCL_1_2_APIS // to disable deprecation warnings
+
// Includes the CLBlast library (C interface)
#include <clblast_c.h>
diff --git a/samples/haxpy.c b/samples/haxpy.c
index 8e0833f8..4f2bb400 100644
--- a/samples/haxpy.c
+++ b/samples/haxpy.c
@@ -18,6 +18,8 @@
#include <stdio.h>
#include <string.h>
+#define CL_USE_DEPRECATED_OPENCL_1_2_APIS // to disable deprecation warnings
+
// Includes the CLBlast library (C interface)
#include <clblast_c.h>
diff --git a/samples/sasum.c b/samples/sasum.c
index c285dd14..78377336 100644
--- a/samples/sasum.c
+++ b/samples/sasum.c
@@ -19,6 +19,8 @@
#include <stdio.h>
#include <string.h>
+#define CL_USE_DEPRECATED_OPENCL_1_2_APIS // to disable deprecation warnings
+
// Includes the CLBlast library (C interface)
#include <clblast_c.h>
diff --git a/samples/sgemm.c b/samples/sgemm.c
index 132dad81..92f3057d 100644
--- a/samples/sgemm.c
+++ b/samples/sgemm.c
@@ -19,6 +19,8 @@
#include <stdio.h>
#include <string.h>
+#define CL_USE_DEPRECATED_OPENCL_1_2_APIS // to disable deprecation warnings
+
// Includes the CLBlast library (C interface)
#include <clblast_c.h>
diff --git a/samples/sgemm.cpp b/samples/sgemm.cpp
index 401ecff8..b960865b 100644
--- a/samples/sgemm.cpp
+++ b/samples/sgemm.cpp
@@ -20,6 +20,9 @@
#include <chrono>
#include <vector>
+#define CL_USE_DEPRECATED_OPENCL_1_1_APIS // to disable deprecation warnings
+#define CL_USE_DEPRECATED_OPENCL_1_2_APIS // to disable deprecation warnings
+
// Includes the C++ OpenCL API. If not yet available, it can be found here:
// https://www.khronos.org/registry/cl/api/1.1/cl.hpp
#include "cl.hpp"
@@ -103,7 +106,7 @@ int main() {
auto time_ms = std::chrono::duration<double,std::milli>(elapsed_time).count();
// Example completed. See "clblast.h" for status codes (0 -> success).
- printf("Completed SGEMM in %.3lf ms with status %d\n", time_ms, status);
+ printf("Completed SGEMM in %.3lf ms with status %d\n", time_ms, static_cast<int>(status));
return 0;
}
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 35d902b7..6591cbf7 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -42,7 +42,7 @@ FILES = [
"/src/clblast_netlib_c.cpp",
]
HEADER_LINES = [117, 73, 118, 22, 29, 41, 65, 32]
-FOOTER_LINES = [17, 80, 19, 18, 6, 6, 9, 2]
+FOOTER_LINES = [17, 95, 19, 18, 6, 6, 9, 2]
# Different possibilities for requirements
ald_m = "The value of `a_ld` must be at least `m`."
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 4bb4e0b3..e0f8add2 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -15,8 +15,8 @@
#include <string>
-#include "clblast.h"
#include "cache.hpp"
+#include "clblast.h"
// BLAS level-1 includes
#include "routines/level1/xswap.hpp"
@@ -2170,6 +2170,71 @@ StatusCode ClearCache() {
return StatusCode::kSuccess;
}
+template <typename Real, typename Complex>
+void FillCacheForPrecision(Queue &queue) {
+ try {
+
+ // Runs all the level 1 set-up functions
+ Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr);
+ Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr);
+ Xscal<Real>(queue, nullptr); Xscal<Complex>(queue, nullptr);
+ Xcopy<Real>(queue, nullptr); Xcopy<Complex>(queue, nullptr);
+ Xaxpy<Real>(queue, nullptr); Xaxpy<Complex>(queue, nullptr);
+ Xdot<Real>(queue, nullptr);
+ Xdotu<Complex>(queue, nullptr);
+ Xdotc<Complex>(queue, nullptr);
+ Xnrm2<Real>(queue, nullptr); Xnrm2<Complex>(queue, nullptr);
+ Xasum<Real>(queue, nullptr); Xasum<Complex>(queue, nullptr);
+ Xsum<Real>(queue, nullptr); Xsum<Complex>(queue, nullptr);
+ Xamax<Real>(queue, nullptr); Xamax<Complex>(queue, nullptr);
+ Xmax<Real>(queue, nullptr); Xmax<Complex>(queue, nullptr);
+ Xmin<Real>(queue, nullptr); Xmin<Complex>(queue, nullptr);
+
+ // Runs all the level 2 set-up functions
+ Xgemv<Real>(queue, nullptr); Xgemv<Complex>(queue, nullptr);
+ Xgbmv<Real>(queue, nullptr); Xgbmv<Complex>(queue, nullptr);
+ Xhemv<Complex>(queue, nullptr);
+ Xhbmv<Complex>(queue, nullptr);
+ Xhpmv<Complex>(queue, nullptr);
+ Xsymv<Real>(queue, nullptr);
+ Xsbmv<Real>(queue, nullptr);
+ Xspmv<Real>(queue, nullptr);
+ Xtrmv<Real>(queue, nullptr); Xtrmv<Complex>(queue, nullptr);
+ Xtbmv<Real>(queue, nullptr); Xtbmv<Complex>(queue, nullptr);
+ Xtpmv<Real>(queue, nullptr); Xtpmv<Complex>(queue, nullptr);
+ Xger<Real>(queue, nullptr);
+ Xgeru<Complex>(queue, nullptr);
+ Xgerc<Complex>(queue, nullptr);
+ Xher<Complex,Real>(queue, nullptr);
+ Xhpr<Complex,Real>(queue, nullptr);
+ Xher2<Complex>(queue, nullptr);
+ Xhpr2<Complex>(queue, nullptr);
+ Xsyr<Real>(queue, nullptr);
+ Xspr<Real>(queue, nullptr);
+ Xsyr2<Real>(queue, nullptr);
+ Xspr2<Real>(queue, nullptr);
+
+ // Runs all the level 3 set-up functions
+ Xgemm<Real>(queue, nullptr); Xgemm<Complex>(queue, nullptr);
+ Xsymm<Real>(queue, nullptr); Xsymm<Complex>(queue, nullptr);
+ Xhemm<Complex>(queue, nullptr);
+ Xsyrk<Real>(queue, nullptr); Xsyrk<Complex>(queue, nullptr);
+ Xherk<Complex,Real>(queue, nullptr);
+ Xsyr2k<Real>(queue, nullptr); Xsyr2k<Complex>(queue, nullptr);
+ Xher2k<Complex,Real>(queue, nullptr);
+ Xtrmm<Real>(queue, nullptr); Xtrmm<Complex>(queue, nullptr);
+
+ // Runs all the non-BLAS set-up functions
+ Xomatcopy<Real>(queue, nullptr); Xomatcopy<Complex>(queue, nullptr);
+
+ } catch(const RuntimeErrorCode &e) {
+ if (e.status() != StatusCode::kNoDoublePrecision &&
+ e.status() != StatusCode::kNoHalfPrecision) {
+ throw;
+ }
+ }
+}
+
// Fills the cache with all binaries for a specific device
// TODO: Add half-precision FP16 set-up calls
StatusCode FillCache(const cl_device_id device) {
@@ -2180,58 +2245,8 @@ StatusCode FillCache(const cl_device_id device) {
auto context = Context(device_cpp);
auto queue = Queue(context, device_cpp);
- // Runs all the level 1 set-up functions
- Xswap<float>(queue, nullptr); Xswap<double>(queue, nullptr); Xswap<float2>(queue, nullptr); Xswap<double2>(queue, nullptr);
- Xswap<float>(queue, nullptr); Xswap<double>(queue, nullptr); Xswap<float2>(queue, nullptr); Xswap<double2>(queue, nullptr);
- Xscal<float>(queue, nullptr); Xscal<double>(queue, nullptr); Xscal<float2>(queue, nullptr); Xscal<double2>(queue, nullptr);
- Xcopy<float>(queue, nullptr); Xcopy<double>(queue, nullptr); Xcopy<float2>(queue, nullptr); Xcopy<double2>(queue, nullptr);
- Xaxpy<float>(queue, nullptr); Xaxpy<double>(queue, nullptr); Xaxpy<float2>(queue, nullptr); Xaxpy<double2>(queue, nullptr);
- Xdot<float>(queue, nullptr); Xdot<double>(queue, nullptr);
- Xdotu<float2>(queue, nullptr); Xdotu<double2>(queue, nullptr);
- Xdotc<float2>(queue, nullptr); Xdotc<double2>(queue, nullptr);
- Xnrm2<float>(queue, nullptr); Xnrm2<double>(queue, nullptr); Xnrm2<float2>(queue, nullptr); Xnrm2<double2>(queue, nullptr);
- Xasum<float>(queue, nullptr); Xasum<double>(queue, nullptr); Xasum<float2>(queue, nullptr); Xasum<double2>(queue, nullptr);
- Xsum<float>(queue, nullptr); Xsum<double>(queue, nullptr); Xsum<float2>(queue, nullptr); Xsum<double2>(queue, nullptr);
- Xamax<float>(queue, nullptr); Xamax<double>(queue, nullptr); Xamax<float2>(queue, nullptr); Xamax<double2>(queue, nullptr);
- Xmax<float>(queue, nullptr); Xmax<double>(queue, nullptr); Xmax<float2>(queue, nullptr); Xmax<double2>(queue, nullptr);
- Xmin<float>(queue, nullptr); Xmin<double>(queue, nullptr); Xmin<float2>(queue, nullptr); Xmin<double2>(queue, nullptr);
-
- // Runs all the level 2 set-up functions
- Xgemv<float>(queue, nullptr); Xgemv<double>(queue, nullptr); Xgemv<float2>(queue, nullptr); Xgemv<double2>(queue, nullptr);
- Xgbmv<float>(queue, nullptr); Xgbmv<double>(queue, nullptr); Xgbmv<float2>(queue, nullptr); Xgbmv<double2>(queue, nullptr);
- Xhemv<float2>(queue, nullptr); Xhemv<double2>(queue, nullptr);
- Xhbmv<float2>(queue, nullptr); Xhbmv<double2>(queue, nullptr);
- Xhpmv<float2>(queue, nullptr); Xhpmv<double2>(queue, nullptr);
- Xsymv<float>(queue, nullptr); Xsymv<double>(queue, nullptr);
- Xsbmv<float>(queue, nullptr); Xsbmv<double>(queue, nullptr);
- Xspmv<float>(queue, nullptr); Xspmv<double>(queue, nullptr);
- Xtrmv<float>(queue, nullptr); Xtrmv<double>(queue, nullptr); Xtrmv<float2>(queue, nullptr); Xtrmv<double2>(queue, nullptr);
- Xtbmv<float>(queue, nullptr); Xtbmv<double>(queue, nullptr); Xtbmv<float2>(queue, nullptr); Xtbmv<double2>(queue, nullptr);
- Xtpmv<float>(queue, nullptr); Xtpmv<double>(queue, nullptr); Xtpmv<float2>(queue, nullptr); Xtpmv<double2>(queue, nullptr);
- Xger<float>(queue, nullptr); Xger<double>(queue, nullptr);
- Xgeru<float2>(queue, nullptr); Xgeru<double2>(queue, nullptr);
- Xgerc<float2>(queue, nullptr); Xgerc<double2>(queue, nullptr);
- Xher<float2,float>(queue, nullptr); Xher<double2,double>(queue, nullptr);
- Xhpr<float2,float>(queue, nullptr); Xhpr<double2,double>(queue, nullptr);
- Xher2<float2>(queue, nullptr); Xher2<double2>(queue, nullptr);
- Xhpr2<float2>(queue, nullptr); Xhpr2<double2>(queue, nullptr);
- Xsyr<float>(queue, nullptr); Xsyr<double>(queue, nullptr);
- Xspr<float>(queue, nullptr); Xspr<double>(queue, nullptr);
- Xsyr2<float>(queue, nullptr); Xsyr2<double>(queue, nullptr);
- Xspr2<float>(queue, nullptr); Xspr2<double>(queue, nullptr);
-
- // Runs all the level 3 set-up functions
- Xgemm<float>(queue, nullptr); Xgemm<double>(queue, nullptr); Xgemm<float2>(queue, nullptr); Xgemm<double2>(queue, nullptr);
- Xsymm<float>(queue, nullptr); Xsymm<double>(queue, nullptr); Xsymm<float2>(queue, nullptr); Xsymm<double2>(queue, nullptr);
- Xhemm<float2>(queue, nullptr); Xhemm<double2>(queue, nullptr);
- Xsyrk<float>(queue, nullptr); Xsyrk<double>(queue, nullptr); Xsyrk<float2>(queue, nullptr); Xsyrk<double2>(queue, nullptr);
- Xherk<float2,float>(queue, nullptr); Xherk<double2,double>(queue, nullptr);
- Xsyr2k<float>(queue, nullptr); Xsyr2k<double>(queue, nullptr); Xsyr2k<float2>(queue, nullptr); Xsyr2k<double2>(queue, nullptr);
- Xher2k<float2,float>(queue, nullptr); Xher2k<double2,double>(queue, nullptr);
- Xtrmm<float>(queue, nullptr); Xtrmm<double>(queue, nullptr); Xtrmm<float2>(queue, nullptr); Xtrmm<double2>(queue, nullptr);
-
- // Runs all the level 3 set-up functions
- Xomatcopy<float>(queue, nullptr); Xomatcopy<double>(queue, nullptr); Xomatcopy<float2>(queue, nullptr); Xomatcopy<double2>(queue, nullptr);
+ FillCacheForPrecision<float, float2>(queue);
+ FillCacheForPrecision<double, double2>(queue);
} catch (...) { return DispatchException(); }
return StatusCode::kSuccess;
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp
index 59e4cd16..e4f2b3ed 100644
--- a/src/clblast_c.cpp
+++ b/src/clblast_c.cpp
@@ -13,9 +13,9 @@
#include <string>
+#include "utilities/utilities.hpp"
#include "clblast_c.h"
#include "clblast.h"
-#include "utilities/utilities.hpp"
// Shortcuts to the clblast namespace
using float2 = clblast::float2;
diff --git a/src/clpp11.hpp b/src/clpp11.hpp
index 0383f53a..c984661c 100644
--- a/src/clpp11.hpp
+++ b/src/clpp11.hpp
@@ -333,7 +333,10 @@ class Context {
// Regular constructor with memory management
explicit Context(const Device &device):
- context_(new cl_context, [](cl_context* c) { CheckErrorDtor(clReleaseContext(*c)); delete c; }) {
+ context_(new cl_context, [](cl_context* c) {
+ if (*c) { CheckErrorDtor(clReleaseContext(*c)); }
+ delete c;
+ }) {
auto status = CL_SUCCESS;
const cl_device_id dev = device();
*context_ = clCreateContext(nullptr, 1, &dev, nullptr, nullptr, &status);
@@ -355,33 +358,37 @@ using ContextPointer = cl_context*;
// Enumeration of build statuses of the run-time compilation process
enum class BuildStatus { kSuccess, kError, kInvalid };
-// C++11 version of 'cl_program'. Additionally holds the program's source code.
+// C++11 version of 'cl_program'.
class Program {
public:
// Note that there is no constructor based on the regular OpenCL data-type because of extra state
// Source-based constructor with memory management
- explicit Program(const Context &context, std::string source):
- program_(new cl_program, [](cl_program* p) { CheckErrorDtor(clReleaseProgram(*p)); delete p; }),
- length_(source.length()),
- source_(std::move(source)),
- source_ptr_(&source_[0]) {
+ explicit Program(const Context &context, const std::string &source):
+ program_(new cl_program, [](cl_program* p) {
+ if (*p) { CheckErrorDtor(clReleaseProgram(*p)); }
+ delete p;
+ }) {
+ const char *source_ptr = &source[0];
+ size_t length = source.length();
auto status = CL_SUCCESS;
- *program_ = clCreateProgramWithSource(context(), 1, &source_ptr_, &length_, &status);
+ *program_ = clCreateProgramWithSource(context(), 1, &source_ptr, &length, &status);
CLError::Check(status, "clCreateProgramWithSource");
}
// Binary-based constructor with memory management
- explicit Program(const Device &device, const Context &context, const std::string& binary):
- program_(new cl_program, [](cl_program* p) { CheckErrorDtor(clReleaseProgram(*p)); delete p; }),
- length_(binary.length()),
- source_(binary),
- source_ptr_(&source_[0]) {
+ explicit Program(const Device &device, const Context &context, const std::string &binary):
+ program_(new cl_program, [](cl_program* p) {
+ if (*p) { CheckErrorDtor(clReleaseProgram(*p)); }
+ delete p;
+ }) {
+ const char *binary_ptr = &binary[0];
+ size_t length = binary.length();
auto status1 = CL_SUCCESS;
auto status2 = CL_SUCCESS;
const cl_device_id dev = device();
- *program_ = clCreateProgramWithBinary(context(), 1, &dev, &length_,
- reinterpret_cast<const unsigned char**>(&source_ptr_),
+ *program_ = clCreateProgramWithBinary(context(), 1, &dev, &length,
+ reinterpret_cast<const unsigned char**>(&binary_ptr),
&status1, &status2);
CLError::Check(status1, "clCreateProgramWithBinary (binary status)");
CLError::Check(status2, "clCreateProgramWithBinary");
@@ -421,9 +428,6 @@ class Program {
const cl_program& operator()() const { return *program_; }
private:
std::shared_ptr<cl_program> program_;
- size_t length_;
- std::string source_; // Note: the source can also be a binary or IR
- const char* source_ptr_;
};
// =================================================================================================
@@ -440,8 +444,10 @@ class Queue {
// Regular constructor with memory management
explicit Queue(const Context &context, const Device &device):
- queue_(new cl_command_queue, [](cl_command_queue* s) { CheckErrorDtor(clReleaseCommandQueue(*s));
- delete s; }) {
+ queue_(new cl_command_queue, [](cl_command_queue* s) {
+ if (*s) { CheckErrorDtor(clReleaseCommandQueue(*s)); }
+ delete s;
+ }) {
auto status = CL_SUCCESS;
*queue_ = clCreateCommandQueue(context(), device(), CL_QUEUE_PROFILING_ENABLE, &status);
CLError::Check(status, "clCreateCommandQueue");
@@ -665,7 +671,10 @@ class Kernel {
// Regular constructor with memory management
explicit Kernel(const Program &program, const std::string &name):
- kernel_(new cl_kernel, [](cl_kernel* k) { CheckErrorDtor(clReleaseKernel(*k)); delete k; }) {
+ kernel_(new cl_kernel, [](cl_kernel* k) {
+ if (*k) { CheckErrorDtor(clReleaseKernel(*k)); }
+ delete k;
+ }) {
auto status = CL_SUCCESS;
*kernel_ = clCreateKernel(program(), name.c_str(), &status);
CLError::Check(status, "clCreateKernel");
diff --git a/src/routine.cpp b/src/routine.cpp
index acafb0d2..d5a6b589 100644
--- a/src/routine.cpp
+++ b/src/routine.cpp
@@ -52,24 +52,21 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
auto program = Program(device_, context_, binary);
program.Build(device_, options);
StoreProgramToCache(program, context_, precision_, routine_name_);
+ return;
}
// Otherwise, the kernel will be compiled and program will be built. Both the binary and the
// program will be added to the cache.
// Inspects whether or not cl_khr_fp64 is supported in case of double precision
- const auto extensions = device_.Capabilities();
- if (precision_ == Precision::kDouble || precision_ == Precision::kComplexDouble) {
- if (extensions.find(kKhronosDoublePrecision) == std::string::npos) {
- throw RuntimeErrorCode(StatusCode::kNoDoublePrecision);
- }
+ if ((precision_ == Precision::kDouble && !PrecisionSupported<double>(device_)) ||
+ (precision_ == Precision::kComplexDouble && !PrecisionSupported<double2>(device_))) {
+ throw RuntimeErrorCode(StatusCode::kNoDoublePrecision);
}
// As above, but for cl_khr_fp16 (half precision)
- if (precision_ == Precision::kHalf) {
- if (extensions.find(kKhronosHalfPrecision) == std::string::npos) {
- throw RuntimeErrorCode(StatusCode::kNoHalfPrecision);
- }
+ if (precision_ == Precision::kHalf && !PrecisionSupported<half>(device_)) {
+ throw RuntimeErrorCode(StatusCode::kNoHalfPrecision);
}
// Collects the parameters for this device in the form of defines, and adds the precision
diff --git a/src/routines/common.hpp b/src/routines/common.hpp
index 53ca6355..7c211c0d 100644
--- a/src/routines/common.hpp
+++ b/src/routines/common.hpp
@@ -19,8 +19,8 @@
#include <string>
#include <vector>
-#include "clblast.h"
#include "clpp11.hpp"
+#include "clblast.h"
#include "database/database.hpp"
namespace clblast {
diff --git a/src/utilities/clblast_exceptions.hpp b/src/utilities/clblast_exceptions.hpp
index f3c7b9a3..0d0033b6 100644
--- a/src/utilities/clblast_exceptions.hpp
+++ b/src/utilities/clblast_exceptions.hpp
@@ -16,8 +16,8 @@
#ifndef CLBLAST_EXCEPTIONS_H_
#define CLBLAST_EXCEPTIONS_H_
-#include "clblast.h"
#include "clpp11.hpp"
+#include "clblast.h"
namespace clblast {
// =================================================================================================
diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp
index 20587bd4..a1d4e2db 100644
--- a/src/utilities/utilities.hpp
+++ b/src/utilities/utilities.hpp
@@ -21,9 +21,9 @@
#include <functional>
#include <complex>
+#include "clpp11.hpp"
#include "clblast.h"
#include "clblast_half.h"
-#include "clpp11.hpp"
#include "utilities/clblast_exceptions.hpp"
#include "utilities/msvc.hpp"
diff --git a/test/correctness/tester.cpp b/test/correctness/tester.cpp
index c449b09d..efe49811 100644
--- a/test/correctness/tester.cpp
+++ b/test/correctness/tester.cpp
@@ -248,8 +248,29 @@ template <typename T, typename U>
void Tester<T,U>::TestErrorCodes(const StatusCode clblas_status, const StatusCode clblast_status,
const Arguments<U> &args) {
+ // Either an OpenCL or CLBlast internal error occurred, fail the test immediately
+ // NOTE: the OpenCL error codes grow downwards without any declared lower bound, hence the magic
+ // number. The last error code is atm around -70, but -500 is chosen to be on the safe side.
+ if (clblast_status != StatusCode::kSuccess &&
+ (clblast_status > static_cast<StatusCode>(-500) /* matches OpenCL errors (see above) */ ||
+ clblast_status < StatusCode::kNotImplemented) /* matches CLBlast internal errors */) {
+ PrintTestResult(kErrorStatus);
+ ReportError({StatusCode::kSuccess, clblast_status, kStatusError, args});
+ if (verbose_) {
+ fprintf(stdout, "\n");
+ PrintErrorLog({{StatusCode::kSuccess, clblast_status, kStatusError, args}});
+ fprintf(stdout, " ");
+ }
+ }
+
+ // Routine is not implemented
+ else if (clblast_status == StatusCode::kNotImplemented) {
+ PrintTestResult(kSkippedCompilation);
+ ReportSkipped();
+ }
+
// Cannot compare error codes against a library other than clBLAS
- if (compare_cblas_) {
+ else if (compare_cblas_) {
PrintTestResult(kUnsupportedReference);
ReportSkipped();
}
@@ -267,13 +288,6 @@ void Tester<T,U>::TestErrorCodes(const StatusCode clblas_status, const StatusCod
ReportSkipped();
}
- // Could not compile the CLBlast kernel properly
- else if (clblast_status == StatusCode::kOpenCLBuildProgramFailure ||
- clblast_status == StatusCode::kNotImplemented) {
- PrintTestResult(kSkippedCompilation);
- ReportSkipped();
- }
-
// Error occurred
else {
PrintTestResult(kErrorStatus);
@@ -388,7 +402,9 @@ void Tester<T,U>::PrintErrorLog(const std::vector<ErrorLogEntry> &error_log) {
fprintf(stdout, " Error rate %.1lf%%: ", entry.error_percentage);
}
else {
- fprintf(stdout, " Status code %d (expected %d): ", entry.status_found, entry.status_expect);
+ fprintf(stdout, " Status code %d (expected %d): ",
+ static_cast<int>(entry.status_found),
+ static_cast<int>(entry.status_expect));
}
fprintf(stdout, "%s\n", GetOptionsString(entry.args).c_str());
}
diff --git a/test/correctness/tester.hpp b/test/correctness/tester.hpp
index d8462cef..113f03ef 100644
--- a/test/correctness/tester.hpp
+++ b/test/correctness/tester.hpp
@@ -22,14 +22,14 @@
#include <vector>
#include <memory>
+#include "utilities/utilities.hpp"
+
// The libraries
#ifdef CLBLAST_REF_CLBLAS
#include <clBLAS.h>
#endif
#include "clblast.h"
-#include "utilities/utilities.hpp"
-
namespace clblast {
// =================================================================================================
diff --git a/test/performance/client.hpp b/test/performance/client.hpp
index 4554c67f..4b3e17c7 100644
--- a/test/performance/client.hpp
+++ b/test/performance/client.hpp
@@ -25,14 +25,14 @@
#include <vector>
#include <utility>
+#include "utilities/utilities.hpp"
+
// The libraries to test
#ifdef CLBLAST_REF_CLBLAS
#include <clBLAS.h>
#endif
#include "clblast.h"
-#include "utilities/utilities.hpp"
-
namespace clblast {
// =================================================================================================