diff options
Diffstat (limited to 'src/routines/level3/xgemm.hpp')
-rw-r--r-- | src/routines/level3/xgemm.hpp | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp index f0911d6a..25b1f5c9 100644 --- a/src/routines/level3/xgemm.hpp +++ b/src/routines/level3/xgemm.hpp @@ -29,6 +29,40 @@ class Xgemm: public Routine { static const bool b_want_rotated_; static const bool c_want_rotated_; + // Computes the size of the temporary GEMM buffer based on user-arguments + static size_t GetTempSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const size_t a_offset, const size_t a_ld, + const size_t b_offset, const size_t b_ld, + const size_t c_offset, const size_t c_ld, + const size_t mwg, const size_t nwg, const size_t kwg) { + + // Computes the transpose/conjugate options and sets the a/b/c sizes based on that + bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate; + size_t a_one, a_two, b_one, b_two, c_one, c_two; + ProcessArguments(layout, a_transpose, b_transpose, m, n, k, + a_one, a_two, b_one, b_two, c_one, c_two, + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); + + // Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account + // whether the matrices need to be rotated or not for the kernel. + size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i; + CalculateInternalDimensions(m, n, k, mwg, nwg, kwg, + a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i); + + // Determines whether or not temporary matrices are needed + auto a_no_temp = NoTempBuffer(a_one, a_one_i, a_two, a_two_i, a_ld, a_offset, a_do_transpose, a_conjugate); + auto b_no_temp = NoTempBuffer(b_one, b_one_i, b_two, b_two_i, b_ld, b_offset, b_do_transpose, b_conjugate); + auto c_no_temp = NoTempBuffer(c_one, c_one_i, c_two, c_two_i, c_ld, c_offset, c_do_transpose, false); + + // Computes the sizes and offsets for (optional) temporary buffers for the 3 matrices + auto b_temp_offset = size_t{0}; + auto c_temp_offset = size_t{0}; + return ComputeTempSize(a_no_temp, b_no_temp, c_no_temp, + a_one_i*a_two_i, b_one_i*b_two_i, c_one_i*c_two_i, + b_temp_offset, c_temp_offset); + } + // Selects which version of GEMM to run static bool UseDirectKernel(const size_t m, const size_t n, const size_t k, const size_t min_indirect_size) { |