summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/cache.hpp8
-rw-r--r--src/clpp11.hpp34
-rw-r--r--src/routine.cpp14
-rw-r--r--src/routine.hpp1
-rw-r--r--src/utilities/utilities.cpp13
-rw-r--r--src/utilities/utilities.hpp4
-rw-r--r--test/test_utilities.cpp4
-rw-r--r--test/test_utilities.hpp4
8 files changed, 48 insertions, 34 deletions
diff --git a/src/cache.hpp b/src/cache.hpp
index f6a948b6..1c8c9d4c 100644
--- a/src/cache.hpp
+++ b/src/cache.hpp
@@ -80,8 +80,8 @@ extern template std::string BinaryCache::Get(const BinaryKeyRef &, bool *) const
// The key struct for the cache of compiled OpenCL programs (context-dependent)
// Order of fields: context, device_id, precision, routine_name (smaller fields first)
-typedef std::tuple<cl_context, cl_device_id, Precision, std::string> ProgramKey;
-typedef std::tuple<const cl_context &, const cl_device_id &, const Precision &, const std::string &> ProgramKeyRef;
+typedef std::tuple<RawContext, RawDeviceID, Precision, std::string> ProgramKey;
+typedef std::tuple<const RawContext &, const RawDeviceID &, const Precision &, const std::string &> ProgramKeyRef;
typedef Cache<ProgramKey, Program> ProgramCache;
@@ -94,8 +94,8 @@ class Database;
// The key struct for the cache of database maps.
// Order of fields: platform_id, device_id, precision, kernel_name (smaller fields first)
-typedef std::tuple<cl_platform_id, cl_device_id, Precision, std::string> DatabaseKey;
-typedef std::tuple<const cl_platform_id &, const cl_device_id &, const Precision &, const std::string &> DatabaseKeyRef;
+typedef std::tuple<RawPlatformID, RawDeviceID, Precision, std::string> DatabaseKey;
+typedef std::tuple<const RawPlatformID &, const RawDeviceID &, const Precision &, const std::string &> DatabaseKeyRef;
typedef Cache<DatabaseKey, Database> DatabaseCache;
diff --git a/src/clpp11.hpp b/src/clpp11.hpp
index b9b7fd5b..97045644 100644
--- a/src/clpp11.hpp
+++ b/src/clpp11.hpp
@@ -79,6 +79,9 @@ class CLCudaAPIError : public ErrorCode<DeviceError, cl_int> {
}
};
+// Exception returned when building a program
+using CLCudaAPIBuildError = CLCudaAPIError;
+
// =================================================================================================
// Error occurred in OpenCL
@@ -179,7 +182,7 @@ class Platform {
}
// Accessor to the private data-member
- const cl_platform_id& operator()() const { return platform_; }
+ const RawPlatformID& operator()() const { return platform_; }
private:
cl_platform_id platform_;
@@ -208,6 +211,9 @@ inline std::vector<Platform> GetAllPlatforms() {
// =================================================================================================
+// Raw device ID type
+using RawDeviceID = cl_device_id;
+
// C++11 version of 'cl_device_id'
class Device {
public:
@@ -270,6 +276,13 @@ class Device {
const auto extensions = Capabilities();
return extensions.find(extension) != std::string::npos;
}
+ bool SupportsFP64() const {
+ return HasExtension("cl_khr_fp64");
+ }
+ bool SupportsFP16() const {
+ if (Name() == "Mali-T628") { return true; } // supports fp16 but not cl_khr_fp16 officially
+ return HasExtension("cl_khr_fp16");
+ }
size_t CoreClock() const {
return static_cast<size_t>(GetInfo<cl_uint>(CL_DEVICE_MAX_CLOCK_FREQUENCY));
@@ -334,7 +347,7 @@ class Device {
}
// Accessor to the private data-member
- const cl_device_id& operator()() const { return device_; }
+ const RawDeviceID& operator()() const { return device_; }
private:
cl_device_id device_;
@@ -368,6 +381,9 @@ class Device {
// =================================================================================================
+// Raw context type
+using RawContext = cl_context;
+
// C++11 version of 'cl_context'
class Context {
public:
@@ -391,8 +407,8 @@ class Context {
}
// Accessor to the private data-member
- const cl_context& operator()() const { return *context_; }
- cl_context* pointer() const { return &(*context_); }
+ const RawContext& operator()() const { return *context_; }
+ RawContext* pointer() const { return &(*context_); }
private:
std::shared_ptr<cl_context> context_;
};
@@ -446,6 +462,11 @@ class Program {
CheckError(clBuildProgram(*program_, 1, &dev, options_string.c_str(), nullptr, nullptr));
}
+ // Confirms whether a certain status code is an actual compilation error or warning
+ bool StatusIsCompilationWarningOrError(const cl_int status) const {
+ return (status == CL_BUILD_PROGRAM_FAILURE);
+ }
+
// Retrieves the warning/error message from the compiler (if any)
std::string GetBuildInfo(const Device &device) const {
auto bytes = size_t{0};
@@ -476,6 +497,9 @@ class Program {
// =================================================================================================
+// Raw command-queue type
+using RawCommandQueue = cl_command_queue;
+
// C++11 version of 'cl_command_queue'
class Queue {
public:
@@ -522,7 +546,7 @@ class Queue {
}
// Accessor to the private data-member
- const cl_command_queue& operator()() const { return *queue_; }
+ const RawCommandQueue& operator()() const { return *queue_; }
private:
std::shared_ptr<cl_command_queue> queue_;
};
diff --git a/src/routine.cpp b/src/routine.cpp
index 0a1b6e30..aaa85fde 100644
--- a/src/routine.cpp
+++ b/src/routine.cpp
@@ -60,7 +60,6 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
event_(event),
context_(queue_.GetContext()),
device_(queue_.GetDevice()),
- platform_(device_.PlatformID()),
db_(kernel_names) {
InitDatabase(userDatabase);
@@ -68,18 +67,19 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
}
void Routine::InitDatabase(const std::vector<database::DatabaseEntry> &userDatabase) {
+ const auto platform_id = device_.PlatformID();
for (const auto &kernel_name : kernel_names_) {
// Queries the cache to see whether or not the kernel parameter database is already there
bool has_db;
- db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ platform_, device_(), precision_, kernel_name },
+ db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ platform_id, device_(), precision_, kernel_name },
&has_db);
if (has_db) { continue; }
// Builds the parameter database for this device and routine set and stores it in the cache
log_debug("Searching database for kernel '" + kernel_name + "'");
db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase);
- DatabaseCache::Instance().Store(DatabaseKey{ platform_, device_(), precision_, kernel_name },
+ DatabaseCache::Instance().Store(DatabaseKey{ platform_id, device_(), precision_, kernel_name },
Database{ db_(kernel_name) });
}
}
@@ -123,13 +123,13 @@ void Routine::InitProgram(std::initializer_list<const char *> source) {
// Otherwise, the kernel will be compiled and program will be built. Both the binary and the
// program will be added to the cache.
- // Inspects whether or not cl_khr_fp64 is supported in case of double precision
+ // Inspects whether or not FP64 is supported in case of double precision
if ((precision_ == Precision::kDouble && !PrecisionSupported<double>(device_)) ||
(precision_ == Precision::kComplexDouble && !PrecisionSupported<double2>(device_))) {
throw RuntimeErrorCode(StatusCode::kNoDoublePrecision);
}
- // As above, but for cl_khr_fp16 (half precision)
+ // As above, but for FP16 (half precision)
if (precision_ == Precision::kHalf && !PrecisionSupported<half>(device_)) {
throw RuntimeErrorCode(StatusCode::kNoHalfPrecision);
}
@@ -188,8 +188,8 @@ void Routine::InitProgram(std::initializer_list<const char *> source) {
program_ = Program(context_, source_string);
try {
program_.Build(device_, options);
- } catch (const CLCudaAPIError &e) {
- if (e.status() == CL_BUILD_PROGRAM_FAILURE) {
+ } catch (const CLCudaAPIBuildError &e) {
+ if (program_.StatusIsCompilationWarningOrError(e.status())) {
fprintf(stdout, "OpenCL compiler error/warning: %s\n",
program_.GetBuildInfo(device_).c_str());
}
diff --git a/src/routine.hpp b/src/routine.hpp
index e77e35ad..a8f1cb6a 100644
--- a/src/routine.hpp
+++ b/src/routine.hpp
@@ -75,7 +75,6 @@ class Routine {
EventPointer event_;
const Context context_;
const Device device_;
- const cl_platform_id platform_;
// Compiled program (either retrieved from cache or compiled in slow path)
Program program_;
diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp
index 4b8d5a09..a5c1d45e 100644
--- a/src/utilities/utilities.cpp
+++ b/src/utilities/utilities.cpp
@@ -391,16 +391,9 @@ template <> Precision PrecisionValue<double2>() { return Precision::kComplexDoub
// Returns false is this precision is not supported by the device
template <> bool PrecisionSupported<float>(const Device &) { return true; }
template <> bool PrecisionSupported<float2>(const Device &) { return true; }
-template <> bool PrecisionSupported<double>(const Device &device) {
- return device.HasExtension(kKhronosDoublePrecision);
-}
-template <> bool PrecisionSupported<double2>(const Device &device) {
- return device.HasExtension(kKhronosDoublePrecision);
-}
-template <> bool PrecisionSupported<half>(const Device &device) {
- if (device.Name() == "Mali-T628") { return true; } // supports fp16 but not cl_khr_fp16 officially
- return device.HasExtension(kKhronosHalfPrecision);
-}
+template <> bool PrecisionSupported<double>(const Device &device) { return device.SupportsFP64(); }
+template <> bool PrecisionSupported<double2>(const Device &device) { return device.SupportsFP64(); }
+template <> bool PrecisionSupported<half>(const Device &device) { return device.SupportsFP16(); }
// =================================================================================================
diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp
index e45c606c..b2949c27 100644
--- a/src/utilities/utilities.hpp
+++ b/src/utilities/utilities.hpp
@@ -31,15 +31,13 @@ namespace clblast {
// =================================================================================================
// Shorthands for half-precision
-using half = cl_half; // based on the OpenCL type, which is actually an 'unsigned short'
+using half = unsigned short; // the 'cl_half' OpenCL type is actually an 'unsigned short'
// Shorthands for complex data-types
using float2 = std::complex<float>;
using double2 = std::complex<double>;
// Khronos OpenCL extensions
-const std::string kKhronosHalfPrecision = "cl_khr_fp16";
-const std::string kKhronosDoublePrecision = "cl_khr_fp64";
const std::string kKhronosAttributesAMD = "cl_amd_device_attribute_query";
const std::string kKhronosAttributesNVIDIA = "cl_nv_device_attribute_query";
diff --git a/test/test_utilities.cpp b/test/test_utilities.cpp
index b8fd94a9..579eb61c 100644
--- a/test/test_utilities.cpp
+++ b/test/test_utilities.cpp
@@ -88,7 +88,7 @@ void FloatToHalfBuffer(std::vector<half>& result, const std::vector<float>& sour
}
// As above, but now for OpenCL data-types instead of std::vectors
-Buffer<float> HalfToFloatBuffer(const Buffer<half>& source, cl_command_queue queue_raw) {
+Buffer<float> HalfToFloatBuffer(const Buffer<half>& source, RawCommandQueue queue_raw) {
const auto size = source.GetSize() / sizeof(half);
auto queue = Queue(queue_raw);
auto context = queue.GetContext();
@@ -99,7 +99,7 @@ Buffer<float> HalfToFloatBuffer(const Buffer<half>& source, cl_command_queue que
result.Write(queue, size, result_cpu);
return result;
}
-void FloatToHalfBuffer(Buffer<half>& result, const Buffer<float>& source, cl_command_queue queue_raw) {
+void FloatToHalfBuffer(Buffer<half>& result, const Buffer<float>& source, RawCommandQueue queue_raw) {
const auto size = source.GetSize() / sizeof(float);
auto queue = Queue(queue_raw);
auto context = queue.GetContext();
diff --git a/test/test_utilities.hpp b/test/test_utilities.hpp
index fc50a754..fe7a9cd2 100644
--- a/test/test_utilities.hpp
+++ b/test/test_utilities.hpp
@@ -89,8 +89,8 @@ std::vector<float> HalfToFloatBuffer(const std::vector<half>& source);
void FloatToHalfBuffer(std::vector<half>& result, const std::vector<float>& source);
// As above, but now for OpenCL data-types instead of std::vectors
-Buffer<float> HalfToFloatBuffer(const Buffer<half>& source, cl_command_queue queue_raw);
-void FloatToHalfBuffer(Buffer<half>& result, const Buffer<float>& source, cl_command_queue queue_raw);
+Buffer<float> HalfToFloatBuffer(const Buffer<half>& source, RawCommandQueue queue_raw);
+void FloatToHalfBuffer(Buffer<half>& result, const Buffer<float>& source, RawCommandQueue queue_raw);
// =================================================================================================
} // namespace clblast