summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
author Cedric Nugteren <web@cedricnugteren.nl> 2018-01-06 10:05:28 +0100
committer Cedric Nugteren <web@cedricnugteren.nl> 2018-01-06 10:05:28 +0100
commit ce069545d4b9ac32a094117de75919087a7bc21e (patch)
tree 94aa7e3293600dce1cf4dee83cb00d4ffb724586 /src
parent 44431daecc63cc4ead3208327bcd70834b3f4bdb (diff)
Added CUDA interface to get temporary-buffer size for GEMM routine
Diffstat (limited to 'src')
-rw-r--r-- src/clblast.cpp 12
-rw-r--r-- src/clblast_cuda.cpp 52
2 files changed, 58 insertions, 6 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 461cf31f..f5e2f1be 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2345,7 +2345,7 @@ StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose,
const size_t a_offset, const size_t a_ld,
const size_t b_offset, const size_t b_ld,
const size_t c_offset, const size_t c_ld,
- RawCommandQueue* queue, size_t& temp_buffer_size) {
+ cl_command_queue* queue, size_t& temp_buffer_size) {
try {
// Retrieves the tuning database
@@ -2371,23 +2371,23 @@ StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose,
template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
- const size_t, const size_t, RawCommandQueue*, size_t&);
+ const size_t, const size_t, cl_command_queue*, size_t&); // explicit instantiation: T = float
template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
- const size_t, const size_t, RawCommandQueue*, size_t&);
+ const size_t, const size_t, cl_command_queue*, size_t&); // explicit instantiation: T = double
template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
- const size_t, const size_t, RawCommandQueue*, size_t&);
+ const size_t, const size_t, cl_command_queue*, size_t&); // explicit instantiation: T = float2
template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
- const size_t, const size_t, RawCommandQueue*, size_t&);
+ const size_t, const size_t, cl_command_queue*, size_t&); // explicit instantiation: T = double2
template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
- const size_t, const size_t, RawCommandQueue*, size_t&);
+ const size_t, const size_t, cl_command_queue*, size_t&); // explicit instantiation: T = half
// =================================================================================================
} // namespace clblast
diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp
index 187443eb..21514c74 100644
--- a/src/clblast_cuda.cpp
+++ b/src/clblast_cuda.cpp
@@ -2437,4 +2437,56 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose,
const CUcontext, const CUdevice);
// =================================================================================================
+
+// Computes the size in bytes of the temporary buffer for the GEMM kernel (optional) for the given problem and 'device', and writes it to 'temp_buffer_size'. Returns StatusCode::kSuccess, or the result of DispatchException() if anything throws.
+template <typename T> // T: element type; used for PrecisionValue<T>() in the database lookup and for sizeof(T) in the byte conversion
+StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ const CUdevice device, size_t& temp_buffer_size) { // temp_buffer_size is the output parameter (bytes)
+ try {
+
+ // Initializes the tuning database for the "Xgemm" and "GemmRoutine" entries on this device (via Routine::InitDatabase)
+ const auto device_cpp = Device(device);
+ const auto kernel_names = std::vector<std::string>{"Xgemm", "GemmRoutine"};
+ Databases db(kernel_names);
+ Routine::InitDatabase(device_cpp, kernel_names, PrecisionValue<T>(), {}, db);
+
+ // Computes the buffer size in elements: zero when the direct kernel is selected, otherwise the value reported by Xgemm<T>::GetTempSize
+ if (Xgemm<T>::UseDirectKernel(m, n, k, db["XGEMM_MIN_INDIRECT_SIZE"])) { // decision uses the XGEMM_MIN_INDIRECT_SIZE database value
+ temp_buffer_size = 0;
+ }
+ else {
+ temp_buffer_size = Xgemm<T>::GetTempSize(layout, a_transpose, b_transpose, m, n, k,
+ a_offset, a_ld, b_offset, b_ld, c_offset, c_ld,
+ db["MWG"], db["NWG"], db["KWG"]); // database values forwarded to GetTempSize
+ }
+ temp_buffer_size *= sizeof(T); // translate from num-elements to bytes (a zero count stays zero)
+ return StatusCode::kSuccess;
+ } catch (...) { return DispatchException(); } // any exception becomes the status returned by DispatchException()
+}
+template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, const CUdevice, size_t&); // explicit instantiation: T = float
+template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, const CUdevice, size_t&); // explicit instantiation: T = double
+template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, const CUdevice, size_t&); // explicit instantiation: T = float2
+template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, const CUdevice, size_t&); // explicit instantiation: T = double2
+template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, const CUdevice, size_t&); // explicit instantiation: T = half
+
+// =================================================================================================
} // namespace clblast