diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-11-02 21:47:14 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-11-02 21:47:14 +0100 |
commit | 9b0a435fb00b845b875590be90acffcd4f3bb009 (patch) | |
tree | 754b523789ef717619b540925c97e7167ba28f06 /src/routines | |
parent | 73272ab97dbd5abe757f6558c9b89665c5ac99d0 (diff) |
Integrated the GEMM routine tuner for kernel selection; added first tuning results
Diffstat (limited to 'src/routines')
-rw-r--r-- | src/routines/level3/xgemm.cpp | 6 | ||||
-rw-r--r-- | src/routines/levelx/xgemmbatched.cpp | 2 |
2 files changed, 5 insertions, 3 deletions
```diff
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index a0063ee2..94392dd0 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -23,7 +23,7 @@ namespace clblast {
 template <typename T>
 Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name):
     Routine(queue, event, name,
-            {"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","KernelSelection"},
+            {"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","GemmRoutine"},
             PrecisionValue<T>(), {}, {
     #include "../../kernels/level3/level3.opencl"
     #include "../../kernels/level3/copy_fast.opencl"
@@ -104,7 +104,9 @@ void Xgemm<T>::DoGemm(const Layout layout,
   // Selects which version of GEMM to run
   const auto m_n_k = static_cast<unsigned long long>(m) * static_cast<unsigned long long>(n) *
                      static_cast<unsigned long long>(k);
-  const auto do_gemm_direct = (m_n_k < static_cast<unsigned long long>(db_["XGEMM_MIN_INDIRECT_SIZE"]));
+  const auto database_value = static_cast<unsigned long long>(db_["XGEMM_MIN_INDIRECT_SIZE"]);
+  const auto min_indirect_size = database_value * database_value * database_value;
+  const auto do_gemm_direct = (m_n_k < min_indirect_size);
   if (do_gemm_direct) { // for small sizes (single kernel)
     GemmDirect(m, n, k, alpha,
                a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
diff --git a/src/routines/levelx/xgemmbatched.cpp b/src/routines/levelx/xgemmbatched.cpp
index 8a015e97..152e7194 100644
--- a/src/routines/levelx/xgemmbatched.cpp
+++ b/src/routines/levelx/xgemmbatched.cpp
@@ -23,7 +23,7 @@ namespace clblast {
 template <typename T>
 XgemmBatched<T>::XgemmBatched(Queue &queue, EventPointer event, const std::string &name):
     Routine(queue, event, name,
-            {"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","KernelSelection"},
+            {"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","GemmRoutine"},
             PrecisionValue<T>(), {}, {
     #include "../../kernels/level3/level3.opencl"
     #include "../../kernels/level3/copy_fast.opencl"
```