diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-01-06 16:08:27 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-01-06 16:08:27 +0100 |
commit | a7ccce196915db7a3b7ea7fe8ea9048f5b1204c6 (patch) | |
tree | 27dd8771ee6f913b5a2dabfae115bbe7fbc9d979 /test/routines/level3/xgemm.hpp | |
parent | 8040a4e355bdf6531eb9c4c5ae1fe4f792899d24 (diff) | |
parent | ad197da08da7ef414db90dbb97e92c575363c280 (diff) |
Merge pull request #238 from CNugteren/gemm_api_with_temp_buffer
GEMM API with optional temp buffer
Diffstat (limited to 'test/routines/level3/xgemm.hpp')
-rw-r--r-- | test/routines/level3/xgemm.hpp | 39 |
1 files changed, 28 insertions, 11 deletions
diff --git a/test/routines/level3/xgemm.hpp b/test/routines/level3/xgemm.hpp index fe8cf7b9..4cfa9c83 100644 --- a/test/routines/level3/xgemm.hpp +++ b/test/routines/level3/xgemm.hpp @@ -37,7 +37,8 @@ class TestXgemm { kArgAOffset, kArgBOffset, kArgCOffset, kArgAlpha, kArgBeta}; } - static std::vector<std::string> BuffersIn() { return {kBufMatA, kBufMatB, kBufMatC}; } + static std::vector<std::string> BuffersIn() { return {kBufMatA, kBufMatB, kBufMatC, + kBufMatAP}; } // used as temp buffer static std::vector<std::string> BuffersOut() { return {kBufMatC}; } // Describes how to obtain the sizes of the buffers @@ -60,10 +61,33 @@ class TestXgemm { } // Describes how to set the sizes of all the buffers - static void SetSizes(Arguments<T> &args) { + static void SetSizes(Arguments<T> &args, Queue &queue) { args.a_size = GetSizeA(args); args.b_size = GetSizeB(args); args.c_size = GetSizeC(args); + + // Optionally (V != 0) enforces indirect (V == 1) or direct (V == 2) kernels + if (V != 0) { + const auto device = queue.GetDevice(); + const auto switch_threshold = (V == 1) ? size_t{0} : size_t{4096}; // large enough for tests + const auto override_status = OverrideParameters(device(), "GemmRoutine", PrecisionValue<T>(), + {{"XGEMM_MIN_INDIRECT_SIZE", switch_threshold}}); + if (override_status != StatusCode::kSuccess) { } + } + + // Sets the size of the temporary buffer (optional argument to GEMM) + auto temp_buffer_size = size_t{0}; + #ifdef OPENCL_API + auto queue_plain = queue(); + GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k, + args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld, + &queue_plain, temp_buffer_size); + #elif CUDA_API + GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k, + args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld, + queue.GetDevice()(), temp_buffer_size); + #endif + args.ap_size = (temp_buffer_size + sizeof(T)) / sizeof(T); // + sizeof(T) to prevent zero } // Describes what the default values of the leading dimensions of the matrices are @@ -83,13 +107,6 @@ class TestXgemm { // Describes how to run the CLBlast routine static StatusCode RunRoutine(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) { - if (V != 0) { - const auto device = queue.GetDevice(); - const auto switch_threshold = (V == 1) ? size_t{0} : size_t{1024 * 1024 * 1024}; // large enough for tests - const auto override_status = OverrideParameters(device(), "GemmRoutine", PrecisionValue<T>(), - {{"XGEMM_MIN_INDIRECT_SIZE", switch_threshold}}); - if (override_status != StatusCode::kSuccess) { return override_status; } - } #ifdef OPENCL_API auto queue_plain = queue(); auto event = cl_event{}; @@ -98,7 +115,7 @@ class TestXgemm { buffers.a_mat(), args.a_offset, args.a_ld, buffers.b_mat(), args.b_offset, args.b_ld, args.beta, buffers.c_mat(), args.c_offset, args.c_ld, - &queue_plain, &event); + &queue_plain, &event, buffers.ap_mat()); // temp buffer if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); } #elif CUDA_API auto status = Gemm(args.layout, args.a_transpose, args.b_transpose, @@ -106,7 +123,7 @@ class TestXgemm { buffers.a_mat(), args.a_offset, args.a_ld, buffers.b_mat(), args.b_offset, args.b_ld, args.beta, buffers.c_mat(), args.c_offset, args.c_ld, - queue.GetContext()(), queue.GetDevice()()); + queue.GetContext()(), queue.GetDevice()(), buffers.ap_mat()); // temp buffer cuStreamSynchronize(queue()); #endif return status; |