summary | refs | log | tree | commit | diff
path: root/test
diff options
context:
space:
mode:
author Cedric Nugteren <web@cedricnugteren.nl> 2018-01-06 10:05:28 +0100
committer Cedric Nugteren <web@cedricnugteren.nl> 2018-01-06 10:05:28 +0100
commit ce069545d4b9ac32a094117de75919087a7bc21e (patch)
tree 94aa7e3293600dce1cf4dee83cb00d4ffb724586 /test
parent 44431daecc63cc4ead3208327bcd70834b3f4bdb (diff)
Added CUDA interface to get temporary-buffer size for GEMM routine
Diffstat (limited to 'test')
-rw-r--r-- test/routines/level3/xgemm.hpp 16
1 files changed, 11 insertions, 5 deletions
diff --git a/test/routines/level3/xgemm.hpp b/test/routines/level3/xgemm.hpp
index a74cab2d..4cfa9c83 100644
--- a/test/routines/level3/xgemm.hpp
+++ b/test/routines/level3/xgemm.hpp
@@ -67,7 +67,6 @@ class TestXgemm {
args.c_size = GetSizeC(args);
// Optionally (V != 0) enforces indirect (V == 1) or direct (V == 2) kernels
- auto queue_plain = queue();
if (V != 0) {
const auto device = queue.GetDevice();
const auto switch_threshold = (V == 1) ? size_t{0} : size_t{4096}; // large enough for tests
@@ -78,9 +77,16 @@ class TestXgemm {
// Sets the size of the temporary buffer (optional argument to GEMM)
auto temp_buffer_size = size_t{0};
- GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
- args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
- &queue_plain, temp_buffer_size);
+ #ifdef OPENCL_API
+ auto queue_plain = queue();
+ GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
+ args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
+ &queue_plain, temp_buffer_size);
+ #elif CUDA_API
+ GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
+ args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
+ queue.GetDevice()(), temp_buffer_size);
+ #endif
args.ap_size = (temp_buffer_size + sizeof(T)) / sizeof(T); // + sizeof(T) to prevent zero
}
@@ -117,7 +123,7 @@ class TestXgemm {
buffers.a_mat(), args.a_offset, args.a_ld,
buffers.b_mat(), args.b_offset, args.b_ld, args.beta,
buffers.c_mat(), args.c_offset, args.c_ld,
- queue.GetContext()(), queue.GetDevice()());
+ queue.GetContext()(), queue.GetDevice()(), buffers.ap_mat()); // temp buffer
cuStreamSynchronize(queue());
#endif
return status;