summary | refs | log | tree | commit | diff
path: root/src/routines
diff options
context:
space:
mode:
author: Cedric Nugteren <web@cedricnugteren.nl> 2017-11-02 21:47:14 +0100
committer: Cedric Nugteren <web@cedricnugteren.nl> 2017-11-02 21:47:14 +0100
commit: 9b0a435fb00b845b875590be90acffcd4f3bb009 (patch)
tree: 754b523789ef717619b540925c97e7167ba28f06 /src/routines
parent: 73272ab97dbd5abe757f6558c9b89665c5ac99d0 (diff)
Integrated the GEMM routine tuner for kernel selection; added first tuning results
Diffstat (limited to 'src/routines')
-rw-r--r-- src/routines/level3/xgemm.cpp | 6
-rw-r--r-- src/routines/levelx/xgemmbatched.cpp | 2
2 files changed, 5 insertions, 3 deletions
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index a0063ee2..94392dd0 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -23,7 +23,7 @@ namespace clblast {
template <typename T>
Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name):
Routine(queue, event, name,
- {"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","KernelSelection"},
+ {"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","GemmRoutine"},
PrecisionValue<T>(), {}, {
#include "../../kernels/level3/level3.opencl"
#include "../../kernels/level3/copy_fast.opencl"
@@ -104,7 +104,9 @@ void Xgemm<T>::DoGemm(const Layout layout,
// Selects which version of GEMM to run
const auto m_n_k = static_cast<unsigned long long>(m) * static_cast<unsigned long long>(n) *
static_cast<unsigned long long>(k);
- const auto do_gemm_direct = (m_n_k < static_cast<unsigned long long>(db_["XGEMM_MIN_INDIRECT_SIZE"]));
+ const auto database_value = static_cast<unsigned long long>(db_["XGEMM_MIN_INDIRECT_SIZE"]);
+ const auto min_indirect_size = database_value * database_value * database_value;
+ const auto do_gemm_direct = (m_n_k < min_indirect_size);
if (do_gemm_direct) { // for small sizes (single kernel)
GemmDirect(m, n, k, alpha,
a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
diff --git a/src/routines/levelx/xgemmbatched.cpp b/src/routines/levelx/xgemmbatched.cpp
index 8a015e97..152e7194 100644
--- a/src/routines/levelx/xgemmbatched.cpp
+++ b/src/routines/levelx/xgemmbatched.cpp
@@ -23,7 +23,7 @@ namespace clblast {
template <typename T>
XgemmBatched<T>::XgemmBatched(Queue &queue, EventPointer event, const std::string &name):
Routine(queue, event, name,
- {"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","KernelSelection"},
+ {"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","GemmRoutine"},
PrecisionValue<T>(), {}, {
#include "../../kernels/level3/level3.opencl"
#include "../../kernels/level3/copy_fast.opencl"