diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-04-15 12:53:32 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2018-04-15 12:53:32 +0200 |
commit | 93610a9cba7b057663eac31ed1d7cae1b24623c1 (patch) | |
tree | 4649dd57d66e652a092e5a2daff0db22fce517d5 | |
parent | f14e6f87d2851936629071a2bb0c39d3a8b1a0e5 (diff) |
Fixed some failing tests for GEMM and batched GEMM routines
-rw-r--r-- | src/routines/level3/xgemm.cpp | 7 | ||||
-rw-r--r-- | src/routines/levelx/xgemmbatched.cpp | 7 | ||||
-rw-r--r-- | src/routines/levelx/xgemmstridedbatched.cpp | 7 | ||||
-rw-r--r-- | test/correctness/misc/override_parameters.cpp | 8 |
4 files changed, 19 insertions, 10 deletions
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp index 6a314280..fd5a20db 100644 --- a/src/routines/level3/xgemm.cpp +++ b/src/routines/level3/xgemm.cpp @@ -59,13 +59,17 @@ void Xgemm<T>::DoGemm(const Layout layout, const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, const Buffer<T> &temp_buffer, const bool temp_buffer_provided) { // optional arguments + // Two methods to choose from, select which one to run + const auto do_gemm_direct = UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]); + const auto gemm_kernel_id = (do_gemm_direct) ? 0 : db_["GEMMK"]; + // Computes the transpose/conjugate options and sets the a/b/c sizes based on that bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate; size_t a_one, a_two, b_one, b_two, c_one, c_two; ProcessArguments(layout, a_transpose, b_transpose, m, n, k, a_one, a_two, b_one, b_two, c_one, c_two, a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate, - db_["GEMMK"]); + gemm_kernel_id); // Tests three matrices (A, B, C) for validity, first from a perspective of the OpenCL buffers and // their sizes, and then from a perspective of parameter values (e.g. m, n, k). Tests whether the @@ -79,7 +83,6 @@ void Xgemm<T>::DoGemm(const Layout layout, TestMatrixC(c_one, c_two, c_buffer, c_offset, c_ld); // Selects which version of GEMM to run - const auto do_gemm_direct = UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]); if (do_gemm_direct) { // for small sizes (single kernel) GemmDirect(m, n, k, alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta, diff --git a/src/routines/levelx/xgemmbatched.cpp b/src/routines/levelx/xgemmbatched.cpp index 2caf1e39..2bbc5007 100644 --- a/src/routines/levelx/xgemmbatched.cpp +++ b/src/routines/levelx/xgemmbatched.cpp @@ -65,13 +65,17 @@ void XgemmBatched<T>::DoGemmBatched(const Layout layout, const Transpose a_trans throw BLASError(StatusCode::kInvalidBatchCount); } + // Two methods to choose from, select which one to run + const auto do_gemm_direct = Xgemm<T>::UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]); + const auto gemm_kernel_id = (do_gemm_direct) ? 0 : db_["GEMMK"]; + // Computes the transpose/conjugate options and sets the a/b/c sizes based on that bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate; size_t a_one, a_two, b_one, b_two, c_one, c_two; Xgemm<T>::ProcessArguments(layout, a_transpose, b_transpose, m, n, k, a_one, a_two, b_one, b_two, c_one, c_two, a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate, - db_["GEMMK"]); + gemm_kernel_id); // Tests the matrices for validity for (auto batch = size_t{0}; batch < batch_count; ++batch) { @@ -97,7 +101,6 @@ void XgemmBatched<T>::DoGemmBatched(const Layout layout, const Transpose a_trans } // Selects which version of the batched GEMM to run - const auto do_gemm_direct = Xgemm<T>::UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]); if (do_gemm_direct) { // single generic kernel BatchedGemmDirect(m, n, k, alphas_device, a_buffer, a_offsets_int, a_ld, b_buffer, b_offsets_int, b_ld, diff --git a/src/routines/levelx/xgemmstridedbatched.cpp b/src/routines/levelx/xgemmstridedbatched.cpp index 8408f75a..30c161cc 100644 --- a/src/routines/levelx/xgemmstridedbatched.cpp +++ b/src/routines/levelx/xgemmstridedbatched.cpp @@ -61,13 +61,17 @@ void XgemmStridedBatched<T>::DoGemmStridedBatched(const Layout layout, const Tra throw BLASError(StatusCode::kInvalidBatchCount); } + // Two methods to choose from, select which one to run + const auto do_gemm_direct = Xgemm<T>::UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]); + const auto gemm_kernel_id = (do_gemm_direct) ? 0 : db_["GEMMK"]; + // Computes the transpose/conjugate options and sets the a/b/c sizes based on that bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate; size_t a_one, a_two, b_one, b_two, c_one, c_two; Xgemm<T>::ProcessArguments(layout, a_transpose, b_transpose, m, n, k, a_one, a_two, b_one, b_two, c_one, c_two, a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate, - db_["GEMMK"]); + gemm_kernel_id); // Tests the matrices for validity for (auto batch = size_t{0}; batch < batch_count; ++batch) { @@ -77,7 +81,6 @@ void XgemmStridedBatched<T>::DoGemmStridedBatched(const Layout layout, const Tra } // Selects which version of the batched GEMM to run - const auto do_gemm_direct = Xgemm<T>::UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]);; if (do_gemm_direct) { // single generic kernel BatchedGemmDirect(m, n, k, alpha, a_buffer, a_offset, a_ld, a_stride, diff --git a/test/correctness/misc/override_parameters.cpp b/test/correctness/misc/override_parameters.cpp index 1edfb2ba..54229c5e 100644 --- a/test/correctness/misc/override_parameters.cpp +++ b/test/correctness/misc/override_parameters.cpp @@ -35,12 +35,12 @@ size_t RunOverrideTests(int argc, char *argv[], const bool silent, const std::st const auto kernel_name = std::string{"Xgemm"}; const auto precision = PrecisionValue<T>(); const auto valid_settings = std::vector<std::unordered_map<std::string,size_t>>{ - { {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} }, - { {"KWG",32}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",32}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",32}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} }, - { {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} }, + { {"GEMMK",0}, {"KREG",1}, {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} }, + { {"GEMMK",0}, {"KREG",1}, {"KWG",32}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",32}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",32}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} }, + { {"GEMMK",0}, {"KREG",1}, {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} }, }; const auto invalid_settings = std::vector<std::unordered_map<std::string,size_t>>{ - { {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0} }, + { {"GEMMK",0}, {"KREG",1}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0} }, }; // Retrieves the arguments |