From 5315b982a9f7a58b047021ab5038781f9e4ac482 Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Wed, 3 Jan 2018 20:20:31 +0100 Subject: Added the temp-buffer to the GEMM testers and clients --- test/routines/level3/xgemm.hpp | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/test/routines/level3/xgemm.hpp b/test/routines/level3/xgemm.hpp index fe8cf7b9..a74cab2d 100644 --- a/test/routines/level3/xgemm.hpp +++ b/test/routines/level3/xgemm.hpp @@ -37,7 +37,8 @@ class TestXgemm { kArgAOffset, kArgBOffset, kArgCOffset, kArgAlpha, kArgBeta}; } - static std::vector BuffersIn() { return {kBufMatA, kBufMatB, kBufMatC}; } + static std::vector BuffersIn() { return {kBufMatA, kBufMatB, kBufMatC, + kBufMatAP}; } // used as temp buffer static std::vector BuffersOut() { return {kBufMatC}; } // Describes how to obtain the sizes of the buffers @@ -60,10 +61,27 @@ class TestXgemm { } // Describes how to set the sizes of all the buffers - static void SetSizes(Arguments &args) { + static void SetSizes(Arguments &args, Queue &queue) { args.a_size = GetSizeA(args); args.b_size = GetSizeB(args); args.c_size = GetSizeC(args); + + // Optionally (V != 0) enforces indirect (V == 1) or direct (V == 2) kernels + auto queue_plain = queue(); + if (V != 0) { + const auto device = queue.GetDevice(); + const auto switch_threshold = (V == 1) ? size_t{0} : size_t{4096}; // large enough for tests + const auto override_status = OverrideParameters(device(), "GemmRoutine", PrecisionValue(), + {{"XGEMM_MIN_INDIRECT_SIZE", switch_threshold}}); + if (override_status != StatusCode::kSuccess) { } + } + + // Sets the size of the temporary buffer (optional argument to GEMM) + auto temp_buffer_size = size_t{0}; + GemmTempBufferSize(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k, + args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld, + &queue_plain, temp_buffer_size); + args.ap_size = (temp_buffer_size + sizeof(T)) / sizeof(T); // + sizeof(T) to prevent zero } // Describes what the default values of the leading dimensions of the matrices are @@ -83,13 +101,6 @@ class TestXgemm { // Describes how to run the CLBlast routine static StatusCode RunRoutine(const Arguments &args, Buffers &buffers, Queue &queue) { - if (V != 0) { - const auto device = queue.GetDevice(); - const auto switch_threshold = (V == 1) ? size_t{0} : size_t{1024 * 1024 * 1024}; // large enough for tests - const auto override_status = OverrideParameters(device(), "GemmRoutine", PrecisionValue(), - {{"XGEMM_MIN_INDIRECT_SIZE", switch_threshold}}); - if (override_status != StatusCode::kSuccess) { return override_status; } - } #ifdef OPENCL_API auto queue_plain = queue(); auto event = cl_event{}; @@ -98,7 +109,7 @@ class TestXgemm { buffers.a_mat(), args.a_offset, args.a_ld, buffers.b_mat(), args.b_offset, args.b_ld, args.beta, buffers.c_mat(), args.c_offset, args.c_ld, - &queue_plain, &event); + &queue_plain, &event, buffers.ap_mat()); // temp buffer if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); } #elif CUDA_API auto status = Gemm(args.layout, args.a_transpose, args.b_transpose, -- cgit v1.2.3