summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- include/clblast_cuda.h 12
-rwxr-xr-x scripts/generator/generator.py 4
-rw-r--r-- src/clblast.cpp 12
-rw-r--r-- src/clblast_cuda.cpp 52
-rw-r--r-- test/routines/level3/xgemm.hpp 16
5 files changed, 83 insertions, 13 deletions
diff --git a/include/clblast_cuda.h b/include/clblast_cuda.h
index 0f510981..e0d1d638 100644
--- a/include/clblast_cuda.h
+++ b/include/clblast_cuda.h
@@ -69,6 +69,7 @@ enum class StatusCode {
kInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small
// Custom additional status codes for CLBlast
+ kInsufficientMemoryTemp = -2050, // Temporary buffer provided to GEMM routine is too small
kInvalidBatchCount = -2049, // The batch count needs to be positive
kInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel
kMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel
@@ -620,6 +621,17 @@ StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const T
// =================================================================================================
+// Retrieves the required size (in bytes) of the temporary buffer that GEMM optionally accepts
+template <typename T>
+StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ const CUdevice device, size_t& temp_buffer_size);
+
+// =================================================================================================
+
// CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on
// for the same device. This cache can be cleared to free up system memory or in case of debugging.
StatusCode PUBLIC_API ClearCache();
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index e0c26140..5fbce2c4 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -46,8 +46,8 @@ FILES = [
"/include/clblast_cuda.h",
"/src/clblast_cuda.cpp",
]
-HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 94, 21]
-FOOTER_LINES = [36, 56, 27, 38, 6, 6, 6, 9, 2, 25, 3]
+HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 95, 21]
+FOOTER_LINES = [36, 56, 27, 38, 6, 6, 6, 9, 2, 36, 55]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 63
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 461cf31f..f5e2f1be 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2345,7 +2345,7 @@ StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose,
const size_t a_offset, const size_t a_ld,
const size_t b_offset, const size_t b_ld,
const size_t c_offset, const size_t c_ld,
- RawCommandQueue* queue, size_t& temp_buffer_size) {
+ cl_command_queue* queue, size_t& temp_buffer_size) {
try {
// Retrieves the tuning database
@@ -2371,23 +2371,23 @@ StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose,
template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
- const size_t, const size_t, RawCommandQueue*, size_t&);
+ const size_t, const size_t, cl_command_queue*, size_t&);
template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
- const size_t, const size_t, RawCommandQueue*, size_t&);
+ const size_t, const size_t, cl_command_queue*, size_t&);
template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
- const size_t, const size_t, RawCommandQueue*, size_t&);
+ const size_t, const size_t, cl_command_queue*, size_t&);
template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
- const size_t, const size_t, RawCommandQueue*, size_t&);
+ const size_t, const size_t, cl_command_queue*, size_t&);
template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
- const size_t, const size_t, RawCommandQueue*, size_t&);
+ const size_t, const size_t, cl_command_queue*, size_t&);
// =================================================================================================
} // namespace clblast
diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp
index 187443eb..21514c74 100644
--- a/src/clblast_cuda.cpp
+++ b/src/clblast_cuda.cpp
@@ -2437,4 +2437,56 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose,
const CUcontext, const CUdevice);
// =================================================================================================
+
+// Retrieves the required size (in bytes) of the temporary buffer that GEMM optionally accepts
+template <typename T>
+StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ const CUdevice device, size_t& temp_buffer_size) {
+ try {
+
+ // Retrieves the tuning database for the 'Xgemm' and 'GemmRoutine' kernels on the given device
+ const auto device_cpp = Device(device);
+ const auto kernel_names = std::vector<std::string>{"Xgemm", "GemmRoutine"};
+ Databases db(kernel_names);
+ Routine::InitDatabase(device_cpp, kernel_names, PrecisionValue<T>(), {}, db);
+
+ // Computes the buffer size in number of elements (zero when the direct kernel is selected)
+ if (Xgemm<T>::UseDirectKernel(m, n, k, db["XGEMM_MIN_INDIRECT_SIZE"])) {
+ temp_buffer_size = 0;
+ }
+ else {
+ temp_buffer_size = Xgemm<T>::GetTempSize(layout, a_transpose, b_transpose, m, n, k,
+ a_offset, a_ld, b_offset, b_ld, c_offset, c_ld,
+ db["MWG"], db["NWG"], db["KWG"]);
+ }
+ temp_buffer_size *= sizeof(T); // translate from num-elements to bytes
+ return StatusCode::kSuccess;
+ } catch (...) { return DispatchException(); }
+}
+template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, const CUdevice, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, const CUdevice, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, const CUdevice, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, const CUdevice, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, const CUdevice, size_t&);
+
+// =================================================================================================
} // namespace clblast
diff --git a/test/routines/level3/xgemm.hpp b/test/routines/level3/xgemm.hpp
index a74cab2d..4cfa9c83 100644
--- a/test/routines/level3/xgemm.hpp
+++ b/test/routines/level3/xgemm.hpp
@@ -67,7 +67,6 @@ class TestXgemm {
args.c_size = GetSizeC(args);
// Optionally (V != 0) enforces indirect (V == 1) or direct (V == 2) kernels
- auto queue_plain = queue();
if (V != 0) {
const auto device = queue.GetDevice();
const auto switch_threshold = (V == 1) ? size_t{0} : size_t{4096}; // large enough for tests
@@ -78,9 +77,16 @@ class TestXgemm {
// Sets the size of the temporary buffer (optional argument to GEMM)
auto temp_buffer_size = size_t{0};
- GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
- args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
- &queue_plain, temp_buffer_size);
+ #ifdef OPENCL_API
+ auto queue_plain = queue();
+ GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
+ args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
+ &queue_plain, temp_buffer_size);
+ #elif CUDA_API
+ GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
+ args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
+ queue.GetDevice()(), temp_buffer_size);
+ #endif
args.ap_size = (temp_buffer_size + sizeof(T)) / sizeof(T); // + sizeof(T) to prevent zero
}
@@ -117,7 +123,7 @@ class TestXgemm {
buffers.a_mat(), args.a_offset, args.a_ld,
buffers.b_mat(), args.b_offset, args.b_ld, args.beta,
buffers.c_mat(), args.c_offset, args.c_ld,
- queue.GetContext()(), queue.GetDevice()());
+ queue.GetContext()(), queue.GetDevice()(), buffers.ap_mat()); // temp buffer
cuStreamSynchronize(queue());
#endif
return status;