diff options
Diffstat (limited to 'src/utilities')
-rw-r--r-- | src/utilities/buffer_test.hpp | 2 | ||||
-rw-r--r-- | src/utilities/clblast_exceptions.cpp | 2 | ||||
-rw-r--r-- | src/utilities/clblast_exceptions.hpp | 3 | ||||
-rw-r--r-- | src/utilities/utilities.cpp | 29 | ||||
-rw-r--r-- | src/utilities/utilities.hpp | 13 |
5 files changed, 24 insertions, 25 deletions
diff --git a/src/utilities/buffer_test.hpp b/src/utilities/buffer_test.hpp index b5693181..fd071434 100644 --- a/src/utilities/buffer_test.hpp +++ b/src/utilities/buffer_test.hpp @@ -15,7 +15,7 @@ #ifndef CLBLAST_BUFFER_TEST_H_ #define CLBLAST_BUFFER_TEST_H_ -#include "clblast.h" +#include "utilities/utilities.hpp" namespace clblast { // ================================================================================================= diff --git a/src/utilities/clblast_exceptions.cpp b/src/utilities/clblast_exceptions.cpp index 96f10860..32526215 100644 --- a/src/utilities/clblast_exceptions.cpp +++ b/src/utilities/clblast_exceptions.cpp @@ -55,7 +55,7 @@ StatusCode DispatchException() } catch (BLASError &e) { // no message is printed for invalid argument errors status = e.status(); - } catch (CLError &e) { + } catch (CLCudaAPIError &e) { message = e.what(); status = static_cast<StatusCode>(e.status()); } catch (RuntimeErrorCode &e) { diff --git a/src/utilities/clblast_exceptions.hpp b/src/utilities/clblast_exceptions.hpp index 0d0033b6..a790be9c 100644 --- a/src/utilities/clblast_exceptions.hpp +++ b/src/utilities/clblast_exceptions.hpp @@ -16,8 +16,7 @@ #ifndef CLBLAST_EXCEPTIONS_H_ #define CLBLAST_EXCEPTIONS_H_ -#include "clpp11.hpp" -#include "clblast.h" +#include "utilities/utilities.hpp" namespace clblast { // ================================================================================================= diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp index 4b8d5a09..f2574104 100644 --- a/src/utilities/utilities.cpp +++ b/src/utilities/utilities.cpp @@ -391,16 +391,9 @@ template <> Precision PrecisionValue<double2>() { return Precision::kComplexDoub // Returns false is this precision is not supported by the device template <> bool PrecisionSupported<float>(const Device &) { return true; } template <> bool PrecisionSupported<float2>(const Device &) { return true; } -template <> bool PrecisionSupported<double>(const Device &device) { - return device.HasExtension(kKhronosDoublePrecision); -} -template <> bool PrecisionSupported<double2>(const Device &device) { - return device.HasExtension(kKhronosDoublePrecision); -} -template <> bool PrecisionSupported<half>(const Device &device) { - if (device.Name() == "Mali-T628") { return true; } // supports fp16 but not cl_khr_fp16 officially - return device.HasExtension(kKhronosHalfPrecision); -} +template <> bool PrecisionSupported<double>(const Device &device) { return device.SupportsFP64(); } +template <> bool PrecisionSupported<double2>(const Device &device) { return device.SupportsFP64(); } +template <> bool PrecisionSupported<half>(const Device &device) { return device.SupportsFP16(); } // ================================================================================================= @@ -420,13 +413,17 @@ std::string GetDeviceVendor(const Device& device) { // Mid-level info std::string GetDeviceArchitecture(const Device& device) { auto device_architecture = std::string{""}; - if (device.HasExtension(kKhronosAttributesNVIDIA)) { + #ifdef CUDA_API device_architecture = device.NVIDIAComputeCapability(); - } - else if (device.HasExtension(kKhronosAttributesAMD)) { - device_architecture = device.Name(); // Name is architecture for AMD APP and AMD ROCm - } - // Note: no else - 'device_architecture' might be the empty string + #else + if (device.HasExtension(kKhronosAttributesNVIDIA)) { + device_architecture = device.NVIDIAComputeCapability(); + } + else if (device.HasExtension(kKhronosAttributesAMD)) { + device_architecture = device.Name(); // Name is architecture for AMD APP and AMD ROCm + } + // Note: no else - 'device_architecture' might be the empty string + #endif for (auto &find_and_replace : device_mapping::kArchitectureNames) { // replacing to common names if (device_architecture == find_and_replace.first) { device_architecture = find_and_replace.second; } diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp index e45c606c..f56226be 100644 --- a/src/utilities/utilities.hpp +++ b/src/utilities/utilities.hpp @@ -21,8 +21,13 @@ #include <complex> #include <random> -#include "clpp11.hpp" -#include "clblast.h" +#ifdef OPENCL_API + #include "clpp11.hpp" + #include "clblast.h" +#elif CUDA_API + #include "cupp11.hpp" + #include "clblast_cuda.h" +#endif #include "clblast_half.h" #include "utilities/clblast_exceptions.hpp" #include "utilities/msvc.hpp" @@ -31,15 +36,13 @@ namespace clblast { // ================================================================================================= // Shorthands for half-precision -using half = cl_half; // based on the OpenCL type, which is actually an 'unsigned short' +using half = unsigned short; // the 'cl_half' OpenCL type is actually an 'unsigned short' // Shorthands for complex data-types using float2 = std::complex<float>; using double2 = std::complex<double>; // Khronos OpenCL extensions -const std::string kKhronosHalfPrecision = "cl_khr_fp16"; -const std::string kKhronosDoublePrecision = "cl_khr_fp64"; const std::string kKhronosAttributesAMD = "cl_amd_device_attribute_query"; const std::string kKhronosAttributesNVIDIA = "cl_nv_device_attribute_query"; |