summary | refs | log | tree | commit | diff
path: root/src/routines/level3
diff options
context:
space:
mode:
author: Cedric Nugteren <web@cedricnugteren.nl> 2018-04-15 12:53:32 +0200
committer: Cedric Nugteren <web@cedricnugteren.nl> 2018-04-15 12:53:32 +0200
commit: 93610a9cba7b057663eac31ed1d7cae1b24623c1 (patch)
tree: 4649dd57d66e652a092e5a2daff0db22fce517d5 /src/routines/level3
parent: f14e6f87d2851936629071a2bb0c39d3a8b1a0e5 (diff)
Fixed some failing tests for GEMM and batched GEMM routines
Diffstat (limited to 'src/routines/level3')
-rw-r--r-- src/routines/level3/xgemm.cpp | 7
1 files changed, 5 insertions, 2 deletions
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index 6a314280..fd5a20db 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -59,13 +59,17 @@ void Xgemm<T>::DoGemm(const Layout layout,
const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
const Buffer<T> &temp_buffer, const bool temp_buffer_provided) { // optional arguments
+ // Two methods to choose from, select which one to run
+ const auto do_gemm_direct = UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]);
+ const auto gemm_kernel_id = (do_gemm_direct) ? 0 : db_["GEMMK"];
+
// Computes the transpose/conjugate options and sets the a/b/c sizes based on that
bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate;
size_t a_one, a_two, b_one, b_two, c_one, c_two;
ProcessArguments(layout, a_transpose, b_transpose, m, n, k,
a_one, a_two, b_one, b_two, c_one, c_two,
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
- db_["GEMMK"]);
+ gemm_kernel_id);
// Tests three matrices (A, B, C) for validity, first from a perspective of the OpenCL buffers and
// their sizes, and then from a perspective of parameter values (e.g. m, n, k). Tests whether the
@@ -79,7 +83,6 @@ void Xgemm<T>::DoGemm(const Layout layout,
TestMatrixC(c_one, c_two, c_buffer, c_offset, c_ld);
// Selects which version of GEMM to run
- const auto do_gemm_direct = UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]);
if (do_gemm_direct) { // for small sizes (single kernel)
GemmDirect(m, n, k, alpha,
a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,