From ce069545d4b9ac32a094117de75919087a7bc21e Mon Sep 17 00:00:00 2001
From: Cedric Nugteren
Date: Sat, 6 Jan 2018 10:05:28 +0100
Subject: Added CUDA interface to get temporary-buffer size for GEMM routine

---
 test/routines/level3/xgemm.hpp | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

(limited to 'test')

diff --git a/test/routines/level3/xgemm.hpp b/test/routines/level3/xgemm.hpp
index a74cab2d..4cfa9c83 100644
--- a/test/routines/level3/xgemm.hpp
+++ b/test/routines/level3/xgemm.hpp
@@ -67,7 +67,6 @@ class TestXgemm {
     args.c_size = GetSizeC(args);
 
     // Optionally (V != 0) enforces indirect (V == 1) or direct (V == 2) kernels
-    auto queue_plain = queue();
     if (V != 0) {
       const auto device = queue.GetDevice();
       const auto switch_threshold = (V == 1) ? size_t{0} : size_t{4096}; // large enough for tests
@@ -78,9 +77,16 @@ class TestXgemm {
 
     // Sets the size of the temporary buffer (optional argument to GEMM)
     auto temp_buffer_size = size_t{0};
-    GemmTempBufferSize(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
-                       args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
-                       &queue_plain, temp_buffer_size);
+    #ifdef OPENCL_API
+      auto queue_plain = queue();
+      GemmTempBufferSize(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
+                         args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
+                         &queue_plain, temp_buffer_size);
+    #elif CUDA_API
+      GemmTempBufferSize(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
+                         args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
+                         queue.GetDevice()(), temp_buffer_size);
+    #endif
     args.ap_size = (temp_buffer_size + sizeof(T)) / sizeof(T); // + sizeof(T) to prevent zero
   }
 
@@ -117,7 +123,7 @@ class TestXgemm {
                          buffers.a_mat(), args.a_offset, args.a_ld,
                          buffers.b_mat(), args.b_offset, args.b_ld, args.beta,
                          buffers.c_mat(), args.c_offset, args.c_ld,
-                         queue.GetContext()(), queue.GetDevice()());
+                         queue.GetContext()(), queue.GetDevice()(), buffers.ap_mat()); // temp buffer
       cuStreamSynchronize(queue());
     #endif
     return status;
-- 
cgit v1.2.3