summary | refs | log | tree | commit | diff
path: root/src/routines/level3
diff options
context:
space:
mode:
author: Cedric Nugteren <web@cedricnugteren.nl> 2016-10-06 19:51:12 +0200
committer: Cedric Nugteren <web@cedricnugteren.nl> 2016-10-06 19:51:12 +0200
commit: a3e67f2be2ea9f964c8077d379ca522c6c439036 (patch)
tree: 71dfd13e2dc6f30fa1913f17f4d6a18e7b61ae9e /src/routines/level3
parent: 7052a00a3edc0d37444c88914ece4c468c3e4e96 (diff)
Added a kernel selection database to select between the direct and indirect GEMM kernels
Diffstat (limited to 'src/routines/level3')
-rw-r--r-- src/routines/level3/xgemm.cpp 11
1 file changed, 6 insertions, 5 deletions
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index 93f5d30c..9d912374 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -22,7 +22,8 @@ namespace clblast {
// Constructor: forwards to base class constructor
template <typename T>
Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name):
- Routine(queue, event, name, {"Copy","Pad","Transpose","Padtranspose","Xgemm", "XgemmDirect"},
+ Routine(queue, event, name,
+ {"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","KernelSelection"},
PrecisionValue<T>()) {
source_string_ =
#include "../../kernels/level3/level3.opencl"
@@ -102,15 +103,15 @@ StatusCode Xgemm<T>::DoGemm(const Layout layout,
status = TestMatrixC(c_one, c_two, c_buffer, c_offset, c_ld);
if (ErrorIn(status)) { return status; }
- // Optionally runs the direct version of GEMM. TODO: Set this based on the arguments
- const auto do_gemm_direct = true; // for now, for testing
- if (do_gemm_direct) {
+ // Selects which version of GEMM to run
+ const auto do_gemm_direct = (m * n * k < db_["XGEMM_MIN_INDIRECT_SIZE"]);
+ if (do_gemm_direct) { // for small sizes (single kernel)
return GemmDirect(m, n, k, alpha,
a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
c_buffer, c_offset, c_ld,
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate);
}
- else {
+ else { // for larger sizes (pre/post-processing plus a very fast kernel)
return GemmIndirect(m, n, k, alpha,
a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
c_buffer, c_offset, c_ld,