diff options
Diffstat (limited to 'src/routines/level3/xgemm.cpp')
-rw-r--r-- | src/routines/level3/xgemm.cpp | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp index 93f5d30c..9d912374 100644 --- a/src/routines/level3/xgemm.cpp +++ b/src/routines/level3/xgemm.cpp @@ -22,7 +22,8 @@ namespace clblast { // Constructor: forwards to base class constructor template <typename T> Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name): - Routine(queue, event, name, {"Copy","Pad","Transpose","Padtranspose","Xgemm", "XgemmDirect"}, + Routine(queue, event, name, + {"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","KernelSelection"}, PrecisionValue<T>()) { source_string_ = #include "../../kernels/level3/level3.opencl" @@ -102,15 +103,15 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout, status = TestMatrixC(c_one, c_two, c_buffer, c_offset, c_ld); if (ErrorIn(status)) { return status; } - // Optionally runs the direct version of GEMM. TODO: Set this based on the arguments - const auto do_gemm_direct = true; // for now, for testing - if (do_gemm_direct) { + // Selects which version of GEMM to run + const auto do_gemm_direct = (m * n * k < db_["XGEMM_MIN_INDIRECT_SIZE"]); + if (do_gemm_direct) { // for small sizes (single kernel) return GemmDirect(m, n, k, alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta, c_buffer, c_offset, c_ld, a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); } - else { + else { // for larger sizes (pre/post-processing plus a very fast kernel) return GemmIndirect(m, n, k, alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta, c_buffer, c_offset, c_ld, |