summaryrefslogtreecommitdiff
path: root/src/utilities
diff options
context:
space:
mode:
Diffstat (limited to 'src/utilities')
-rw-r--r--src/utilities/buffer_test.hpp2
-rw-r--r--src/utilities/clblast_exceptions.cpp2
-rw-r--r--src/utilities/clblast_exceptions.hpp3
-rw-r--r--src/utilities/utilities.cpp29
-rw-r--r--src/utilities/utilities.hpp13
5 files changed, 24 insertions, 25 deletions
diff --git a/src/utilities/buffer_test.hpp b/src/utilities/buffer_test.hpp
index b5693181..fd071434 100644
--- a/src/utilities/buffer_test.hpp
+++ b/src/utilities/buffer_test.hpp
@@ -15,7 +15,7 @@
#ifndef CLBLAST_BUFFER_TEST_H_
#define CLBLAST_BUFFER_TEST_H_
-#include "clblast.h"
+#include "utilities/utilities.hpp"
namespace clblast {
// =================================================================================================
diff --git a/src/utilities/clblast_exceptions.cpp b/src/utilities/clblast_exceptions.cpp
index 96f10860..32526215 100644
--- a/src/utilities/clblast_exceptions.cpp
+++ b/src/utilities/clblast_exceptions.cpp
@@ -55,7 +55,7 @@ StatusCode DispatchException()
} catch (BLASError &e) {
// no message is printed for invalid argument errors
status = e.status();
- } catch (CLError &e) {
+ } catch (CLCudaAPIError &e) {
message = e.what();
status = static_cast<StatusCode>(e.status());
} catch (RuntimeErrorCode &e) {
diff --git a/src/utilities/clblast_exceptions.hpp b/src/utilities/clblast_exceptions.hpp
index 0d0033b6..a790be9c 100644
--- a/src/utilities/clblast_exceptions.hpp
+++ b/src/utilities/clblast_exceptions.hpp
@@ -16,8 +16,7 @@
#ifndef CLBLAST_EXCEPTIONS_H_
#define CLBLAST_EXCEPTIONS_H_
-#include "clpp11.hpp"
-#include "clblast.h"
+#include "utilities/utilities.hpp"
namespace clblast {
// =================================================================================================
diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp
index 4b8d5a09..f2574104 100644
--- a/src/utilities/utilities.cpp
+++ b/src/utilities/utilities.cpp
@@ -391,16 +391,9 @@ template <> Precision PrecisionValue<double2>() { return Precision::kComplexDoub
// Returns false is this precision is not supported by the device
template <> bool PrecisionSupported<float>(const Device &) { return true; }
template <> bool PrecisionSupported<float2>(const Device &) { return true; }
-template <> bool PrecisionSupported<double>(const Device &device) {
- return device.HasExtension(kKhronosDoublePrecision);
-}
-template <> bool PrecisionSupported<double2>(const Device &device) {
- return device.HasExtension(kKhronosDoublePrecision);
-}
-template <> bool PrecisionSupported<half>(const Device &device) {
- if (device.Name() == "Mali-T628") { return true; } // supports fp16 but not cl_khr_fp16 officially
- return device.HasExtension(kKhronosHalfPrecision);
-}
+template <> bool PrecisionSupported<double>(const Device &device) { return device.SupportsFP64(); }
+template <> bool PrecisionSupported<double2>(const Device &device) { return device.SupportsFP64(); }
+template <> bool PrecisionSupported<half>(const Device &device) { return device.SupportsFP16(); }
// =================================================================================================
@@ -420,13 +413,17 @@ std::string GetDeviceVendor(const Device& device) {
// Mid-level info
std::string GetDeviceArchitecture(const Device& device) {
auto device_architecture = std::string{""};
- if (device.HasExtension(kKhronosAttributesNVIDIA)) {
+ #ifdef CUDA_API
device_architecture = device.NVIDIAComputeCapability();
- }
- else if (device.HasExtension(kKhronosAttributesAMD)) {
- device_architecture = device.Name(); // Name is architecture for AMD APP and AMD ROCm
- }
- // Note: no else - 'device_architecture' might be the empty string
+ #else
+ if (device.HasExtension(kKhronosAttributesNVIDIA)) {
+ device_architecture = device.NVIDIAComputeCapability();
+ }
+ else if (device.HasExtension(kKhronosAttributesAMD)) {
+ device_architecture = device.Name(); // Name is architecture for AMD APP and AMD ROCm
+ }
+ // Note: no else - 'device_architecture' might be the empty string
+ #endif
for (auto &find_and_replace : device_mapping::kArchitectureNames) { // replacing to common names
if (device_architecture == find_and_replace.first) { device_architecture = find_and_replace.second; }
diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp
index e45c606c..f56226be 100644
--- a/src/utilities/utilities.hpp
+++ b/src/utilities/utilities.hpp
@@ -21,8 +21,13 @@
#include <complex>
#include <random>
-#include "clpp11.hpp"
-#include "clblast.h"
+#ifdef OPENCL_API
+ #include "clpp11.hpp"
+ #include "clblast.h"
+#elif CUDA_API
+ #include "cupp11.hpp"
+ #include "clblast_cuda.h"
+#endif
#include "clblast_half.h"
#include "utilities/clblast_exceptions.hpp"
#include "utilities/msvc.hpp"
@@ -31,15 +36,13 @@ namespace clblast {
// =================================================================================================
// Shorthands for half-precision
-using half = cl_half; // based on the OpenCL type, which is actually an 'unsigned short'
+using half = unsigned short; // the 'cl_half' OpenCL type is actually an 'unsigned short'
// Shorthands for complex data-types
using float2 = std::complex<float>;
using double2 = std::complex<double>;
// Khronos OpenCL extensions
-const std::string kKhronosHalfPrecision = "cl_khr_fp16";
-const std::string kKhronosDoublePrecision = "cl_khr_fp64";
const std::string kKhronosAttributesAMD = "cl_amd_device_attribute_query";
const std::string kKhronosAttributesNVIDIA = "cl_nv_device_attribute_query";