summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/clblast.cpp121
-rw-r--r--src/clblast_c.cpp2
-rw-r--r--src/clpp11.hpp51
-rw-r--r--src/routine.cpp15
-rw-r--r--src/routines/common.hpp2
-rw-r--r--src/utilities/clblast_exceptions.hpp2
-rw-r--r--src/utilities/utilities.hpp2
7 files changed, 108 insertions, 87 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 4bb4e0b3..e0f8add2 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -15,8 +15,8 @@
#include <string>
-#include "clblast.h"
#include "cache.hpp"
+#include "clblast.h"
// BLAS level-1 includes
#include "routines/level1/xswap.hpp"
@@ -2170,6 +2170,71 @@ StatusCode ClearCache() {
return StatusCode::kSuccess;
}
+template <typename Real, typename Complex>
+void FillCacheForPrecision(Queue &queue) {
+ try {
+
+ // Runs all the level 1 set-up functions
+ Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr);
+ Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr);
+ Xscal<Real>(queue, nullptr); Xscal<Complex>(queue, nullptr);
+ Xcopy<Real>(queue, nullptr); Xcopy<Complex>(queue, nullptr);
+ Xaxpy<Real>(queue, nullptr); Xaxpy<Complex>(queue, nullptr);
+ Xdot<Real>(queue, nullptr);
+ Xdotu<Complex>(queue, nullptr);
+ Xdotc<Complex>(queue, nullptr);
+ Xnrm2<Real>(queue, nullptr); Xnrm2<Complex>(queue, nullptr);
+ Xasum<Real>(queue, nullptr); Xasum<Complex>(queue, nullptr);
+ Xsum<Real>(queue, nullptr); Xsum<Complex>(queue, nullptr);
+ Xamax<Real>(queue, nullptr); Xamax<Complex>(queue, nullptr);
+ Xmax<Real>(queue, nullptr); Xmax<Complex>(queue, nullptr);
+ Xmin<Real>(queue, nullptr); Xmin<Complex>(queue, nullptr);
+
+ // Runs all the level 2 set-up functions
+ Xgemv<Real>(queue, nullptr); Xgemv<Complex>(queue, nullptr);
+ Xgbmv<Real>(queue, nullptr); Xgbmv<Complex>(queue, nullptr);
+ Xhemv<Complex>(queue, nullptr);
+ Xhbmv<Complex>(queue, nullptr);
+ Xhpmv<Complex>(queue, nullptr);
+ Xsymv<Real>(queue, nullptr);
+ Xsbmv<Real>(queue, nullptr);
+ Xspmv<Real>(queue, nullptr);
+ Xtrmv<Real>(queue, nullptr); Xtrmv<Complex>(queue, nullptr);
+ Xtbmv<Real>(queue, nullptr); Xtbmv<Complex>(queue, nullptr);
+ Xtpmv<Real>(queue, nullptr); Xtpmv<Complex>(queue, nullptr);
+ Xger<Real>(queue, nullptr);
+ Xgeru<Complex>(queue, nullptr);
+ Xgerc<Complex>(queue, nullptr);
+ Xher<Complex,Real>(queue, nullptr);
+ Xhpr<Complex,Real>(queue, nullptr);
+ Xher2<Complex>(queue, nullptr);
+ Xhpr2<Complex>(queue, nullptr);
+ Xsyr<Real>(queue, nullptr);
+ Xspr<Real>(queue, nullptr);
+ Xsyr2<Real>(queue, nullptr);
+ Xspr2<Real>(queue, nullptr);
+
+ // Runs all the level 3 set-up functions
+ Xgemm<Real>(queue, nullptr); Xgemm<Complex>(queue, nullptr);
+ Xsymm<Real>(queue, nullptr); Xsymm<Complex>(queue, nullptr);
+ Xhemm<Complex>(queue, nullptr);
+ Xsyrk<Real>(queue, nullptr); Xsyrk<Complex>(queue, nullptr);
+ Xherk<Complex,Real>(queue, nullptr);
+ Xsyr2k<Real>(queue, nullptr); Xsyr2k<Complex>(queue, nullptr);
+ Xher2k<Complex,Real>(queue, nullptr);
+ Xtrmm<Real>(queue, nullptr); Xtrmm<Complex>(queue, nullptr);
+
+ // Runs all the non-BLAS set-up functions
+ Xomatcopy<Real>(queue, nullptr); Xomatcopy<Complex>(queue, nullptr);
+
+ } catch(const RuntimeErrorCode &e) {
+ if (e.status() != StatusCode::kNoDoublePrecision &&
+ e.status() != StatusCode::kNoHalfPrecision) {
+ throw;
+ }
+ }
+}
+
// Fills the cache with all binaries for a specific device
// TODO: Add half-precision FP16 set-up calls
StatusCode FillCache(const cl_device_id device) {
@@ -2180,58 +2245,8 @@ StatusCode FillCache(const cl_device_id device) {
auto context = Context(device_cpp);
auto queue = Queue(context, device_cpp);
- // Runs all the level 1 set-up functions
- Xswap<float>(queue, nullptr); Xswap<double>(queue, nullptr); Xswap<float2>(queue, nullptr); Xswap<double2>(queue, nullptr);
- Xswap<float>(queue, nullptr); Xswap<double>(queue, nullptr); Xswap<float2>(queue, nullptr); Xswap<double2>(queue, nullptr);
- Xscal<float>(queue, nullptr); Xscal<double>(queue, nullptr); Xscal<float2>(queue, nullptr); Xscal<double2>(queue, nullptr);
- Xcopy<float>(queue, nullptr); Xcopy<double>(queue, nullptr); Xcopy<float2>(queue, nullptr); Xcopy<double2>(queue, nullptr);
- Xaxpy<float>(queue, nullptr); Xaxpy<double>(queue, nullptr); Xaxpy<float2>(queue, nullptr); Xaxpy<double2>(queue, nullptr);
- Xdot<float>(queue, nullptr); Xdot<double>(queue, nullptr);
- Xdotu<float2>(queue, nullptr); Xdotu<double2>(queue, nullptr);
- Xdotc<float2>(queue, nullptr); Xdotc<double2>(queue, nullptr);
- Xnrm2<float>(queue, nullptr); Xnrm2<double>(queue, nullptr); Xnrm2<float2>(queue, nullptr); Xnrm2<double2>(queue, nullptr);
- Xasum<float>(queue, nullptr); Xasum<double>(queue, nullptr); Xasum<float2>(queue, nullptr); Xasum<double2>(queue, nullptr);
- Xsum<float>(queue, nullptr); Xsum<double>(queue, nullptr); Xsum<float2>(queue, nullptr); Xsum<double2>(queue, nullptr);
- Xamax<float>(queue, nullptr); Xamax<double>(queue, nullptr); Xamax<float2>(queue, nullptr); Xamax<double2>(queue, nullptr);
- Xmax<float>(queue, nullptr); Xmax<double>(queue, nullptr); Xmax<float2>(queue, nullptr); Xmax<double2>(queue, nullptr);
- Xmin<float>(queue, nullptr); Xmin<double>(queue, nullptr); Xmin<float2>(queue, nullptr); Xmin<double2>(queue, nullptr);
-
- // Runs all the level 2 set-up functions
- Xgemv<float>(queue, nullptr); Xgemv<double>(queue, nullptr); Xgemv<float2>(queue, nullptr); Xgemv<double2>(queue, nullptr);
- Xgbmv<float>(queue, nullptr); Xgbmv<double>(queue, nullptr); Xgbmv<float2>(queue, nullptr); Xgbmv<double2>(queue, nullptr);
- Xhemv<float2>(queue, nullptr); Xhemv<double2>(queue, nullptr);
- Xhbmv<float2>(queue, nullptr); Xhbmv<double2>(queue, nullptr);
- Xhpmv<float2>(queue, nullptr); Xhpmv<double2>(queue, nullptr);
- Xsymv<float>(queue, nullptr); Xsymv<double>(queue, nullptr);
- Xsbmv<float>(queue, nullptr); Xsbmv<double>(queue, nullptr);
- Xspmv<float>(queue, nullptr); Xspmv<double>(queue, nullptr);
- Xtrmv<float>(queue, nullptr); Xtrmv<double>(queue, nullptr); Xtrmv<float2>(queue, nullptr); Xtrmv<double2>(queue, nullptr);
- Xtbmv<float>(queue, nullptr); Xtbmv<double>(queue, nullptr); Xtbmv<float2>(queue, nullptr); Xtbmv<double2>(queue, nullptr);
- Xtpmv<float>(queue, nullptr); Xtpmv<double>(queue, nullptr); Xtpmv<float2>(queue, nullptr); Xtpmv<double2>(queue, nullptr);
- Xger<float>(queue, nullptr); Xger<double>(queue, nullptr);
- Xgeru<float2>(queue, nullptr); Xgeru<double2>(queue, nullptr);
- Xgerc<float2>(queue, nullptr); Xgerc<double2>(queue, nullptr);
- Xher<float2,float>(queue, nullptr); Xher<double2,double>(queue, nullptr);
- Xhpr<float2,float>(queue, nullptr); Xhpr<double2,double>(queue, nullptr);
- Xher2<float2>(queue, nullptr); Xher2<double2>(queue, nullptr);
- Xhpr2<float2>(queue, nullptr); Xhpr2<double2>(queue, nullptr);
- Xsyr<float>(queue, nullptr); Xsyr<double>(queue, nullptr);
- Xspr<float>(queue, nullptr); Xspr<double>(queue, nullptr);
- Xsyr2<float>(queue, nullptr); Xsyr2<double>(queue, nullptr);
- Xspr2<float>(queue, nullptr); Xspr2<double>(queue, nullptr);
-
- // Runs all the level 3 set-up functions
- Xgemm<float>(queue, nullptr); Xgemm<double>(queue, nullptr); Xgemm<float2>(queue, nullptr); Xgemm<double2>(queue, nullptr);
- Xsymm<float>(queue, nullptr); Xsymm<double>(queue, nullptr); Xsymm<float2>(queue, nullptr); Xsymm<double2>(queue, nullptr);
- Xhemm<float2>(queue, nullptr); Xhemm<double2>(queue, nullptr);
- Xsyrk<float>(queue, nullptr); Xsyrk<double>(queue, nullptr); Xsyrk<float2>(queue, nullptr); Xsyrk<double2>(queue, nullptr);
- Xherk<float2,float>(queue, nullptr); Xherk<double2,double>(queue, nullptr);
- Xsyr2k<float>(queue, nullptr); Xsyr2k<double>(queue, nullptr); Xsyr2k<float2>(queue, nullptr); Xsyr2k<double2>(queue, nullptr);
- Xher2k<float2,float>(queue, nullptr); Xher2k<double2,double>(queue, nullptr);
- Xtrmm<float>(queue, nullptr); Xtrmm<double>(queue, nullptr); Xtrmm<float2>(queue, nullptr); Xtrmm<double2>(queue, nullptr);
-
- // Runs all the level 3 set-up functions
- Xomatcopy<float>(queue, nullptr); Xomatcopy<double>(queue, nullptr); Xomatcopy<float2>(queue, nullptr); Xomatcopy<double2>(queue, nullptr);
+ FillCacheForPrecision<float, float2>(queue);
+ FillCacheForPrecision<double, double2>(queue);
} catch (...) { return DispatchException(); }
return StatusCode::kSuccess;
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp
index 59e4cd16..e4f2b3ed 100644
--- a/src/clblast_c.cpp
+++ b/src/clblast_c.cpp
@@ -13,9 +13,9 @@
#include <string>
+#include "utilities/utilities.hpp"
#include "clblast_c.h"
#include "clblast.h"
-#include "utilities/utilities.hpp"
// Shortcuts to the clblast namespace
using float2 = clblast::float2;
diff --git a/src/clpp11.hpp b/src/clpp11.hpp
index 0383f53a..c984661c 100644
--- a/src/clpp11.hpp
+++ b/src/clpp11.hpp
@@ -333,7 +333,10 @@ class Context {
// Regular constructor with memory management
explicit Context(const Device &device):
- context_(new cl_context, [](cl_context* c) { CheckErrorDtor(clReleaseContext(*c)); delete c; }) {
+ context_(new cl_context, [](cl_context* c) {
+ if (*c) { CheckErrorDtor(clReleaseContext(*c)); }
+ delete c;
+ }) {
auto status = CL_SUCCESS;
const cl_device_id dev = device();
*context_ = clCreateContext(nullptr, 1, &dev, nullptr, nullptr, &status);
@@ -355,33 +358,37 @@ using ContextPointer = cl_context*;
// Enumeration of build statuses of the run-time compilation process
enum class BuildStatus { kSuccess, kError, kInvalid };
-// C++11 version of 'cl_program'. Additionally holds the program's source code.
+// C++11 version of 'cl_program'.
class Program {
public:
// Note that there is no constructor based on the regular OpenCL data-type because of extra state
// Source-based constructor with memory management
- explicit Program(const Context &context, std::string source):
- program_(new cl_program, [](cl_program* p) { CheckErrorDtor(clReleaseProgram(*p)); delete p; }),
- length_(source.length()),
- source_(std::move(source)),
- source_ptr_(&source_[0]) {
+ explicit Program(const Context &context, const std::string &source):
+ program_(new cl_program, [](cl_program* p) {
+ if (*p) { CheckErrorDtor(clReleaseProgram(*p)); }
+ delete p;
+ }) {
+ const char *source_ptr = &source[0];
+ size_t length = source.length();
auto status = CL_SUCCESS;
- *program_ = clCreateProgramWithSource(context(), 1, &source_ptr_, &length_, &status);
+ *program_ = clCreateProgramWithSource(context(), 1, &source_ptr, &length, &status);
CLError::Check(status, "clCreateProgramWithSource");
}
// Binary-based constructor with memory management
- explicit Program(const Device &device, const Context &context, const std::string& binary):
- program_(new cl_program, [](cl_program* p) { CheckErrorDtor(clReleaseProgram(*p)); delete p; }),
- length_(binary.length()),
- source_(binary),
- source_ptr_(&source_[0]) {
+ explicit Program(const Device &device, const Context &context, const std::string &binary):
+ program_(new cl_program, [](cl_program* p) {
+ if (*p) { CheckErrorDtor(clReleaseProgram(*p)); }
+ delete p;
+ }) {
+ const char *binary_ptr = &binary[0];
+ size_t length = binary.length();
auto status1 = CL_SUCCESS;
auto status2 = CL_SUCCESS;
const cl_device_id dev = device();
- *program_ = clCreateProgramWithBinary(context(), 1, &dev, &length_,
- reinterpret_cast<const unsigned char**>(&source_ptr_),
+ *program_ = clCreateProgramWithBinary(context(), 1, &dev, &length,
+ reinterpret_cast<const unsigned char**>(&binary_ptr),
&status1, &status2);
CLError::Check(status1, "clCreateProgramWithBinary (binary status)");
CLError::Check(status2, "clCreateProgramWithBinary");
@@ -421,9 +428,6 @@ class Program {
const cl_program& operator()() const { return *program_; }
private:
std::shared_ptr<cl_program> program_;
- size_t length_;
- std::string source_; // Note: the source can also be a binary or IR
- const char* source_ptr_;
};
// =================================================================================================
@@ -440,8 +444,10 @@ class Queue {
// Regular constructor with memory management
explicit Queue(const Context &context, const Device &device):
- queue_(new cl_command_queue, [](cl_command_queue* s) { CheckErrorDtor(clReleaseCommandQueue(*s));
- delete s; }) {
+ queue_(new cl_command_queue, [](cl_command_queue* s) {
+ if (*s) { CheckErrorDtor(clReleaseCommandQueue(*s)); }
+ delete s;
+ }) {
auto status = CL_SUCCESS;
*queue_ = clCreateCommandQueue(context(), device(), CL_QUEUE_PROFILING_ENABLE, &status);
CLError::Check(status, "clCreateCommandQueue");
@@ -665,7 +671,10 @@ class Kernel {
// Regular constructor with memory management
explicit Kernel(const Program &program, const std::string &name):
- kernel_(new cl_kernel, [](cl_kernel* k) { CheckErrorDtor(clReleaseKernel(*k)); delete k; }) {
+ kernel_(new cl_kernel, [](cl_kernel* k) {
+ if (*k) { CheckErrorDtor(clReleaseKernel(*k)); }
+ delete k;
+ }) {
auto status = CL_SUCCESS;
*kernel_ = clCreateKernel(program(), name.c_str(), &status);
CLError::Check(status, "clCreateKernel");
diff --git a/src/routine.cpp b/src/routine.cpp
index acafb0d2..d5a6b589 100644
--- a/src/routine.cpp
+++ b/src/routine.cpp
@@ -52,24 +52,21 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
auto program = Program(device_, context_, binary);
program.Build(device_, options);
StoreProgramToCache(program, context_, precision_, routine_name_);
+ return;
}
// Otherwise, the kernel will be compiled and program will be built. Both the binary and the
// program will be added to the cache.
// Inspects whether or not cl_khr_fp64 is supported in case of double precision
- const auto extensions = device_.Capabilities();
- if (precision_ == Precision::kDouble || precision_ == Precision::kComplexDouble) {
- if (extensions.find(kKhronosDoublePrecision) == std::string::npos) {
- throw RuntimeErrorCode(StatusCode::kNoDoublePrecision);
- }
+ if ((precision_ == Precision::kDouble && !PrecisionSupported<double>(device_)) ||
+ (precision_ == Precision::kComplexDouble && !PrecisionSupported<double2>(device_))) {
+ throw RuntimeErrorCode(StatusCode::kNoDoublePrecision);
}
// As above, but for cl_khr_fp16 (half precision)
- if (precision_ == Precision::kHalf) {
- if (extensions.find(kKhronosHalfPrecision) == std::string::npos) {
- throw RuntimeErrorCode(StatusCode::kNoHalfPrecision);
- }
+ if (precision_ == Precision::kHalf && !PrecisionSupported<half>(device_)) {
+ throw RuntimeErrorCode(StatusCode::kNoHalfPrecision);
}
// Collects the parameters for this device in the form of defines, and adds the precision
diff --git a/src/routines/common.hpp b/src/routines/common.hpp
index 53ca6355..7c211c0d 100644
--- a/src/routines/common.hpp
+++ b/src/routines/common.hpp
@@ -19,8 +19,8 @@
#include <string>
#include <vector>
-#include "clblast.h"
#include "clpp11.hpp"
+#include "clblast.h"
#include "database/database.hpp"
namespace clblast {
diff --git a/src/utilities/clblast_exceptions.hpp b/src/utilities/clblast_exceptions.hpp
index f3c7b9a3..0d0033b6 100644
--- a/src/utilities/clblast_exceptions.hpp
+++ b/src/utilities/clblast_exceptions.hpp
@@ -16,8 +16,8 @@
#ifndef CLBLAST_EXCEPTIONS_H_
#define CLBLAST_EXCEPTIONS_H_
-#include "clblast.h"
#include "clpp11.hpp"
+#include "clblast.h"
namespace clblast {
// =================================================================================================
diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp
index 20587bd4..a1d4e2db 100644
--- a/src/utilities/utilities.hpp
+++ b/src/utilities/utilities.hpp
@@ -21,9 +21,9 @@
#include <functional>
#include <complex>
+#include "clpp11.hpp"
#include "clblast.h"
#include "clblast_half.h"
-#include "clpp11.hpp"
#include "utilities/clblast_exceptions.hpp"
#include "utilities/msvc.hpp"