summaryrefslogtreecommitdiff
path: root/src/routines/level3/xgemm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/routines/level3/xgemm.hpp')
-rw-r--r--src/routines/level3/xgemm.hpp34
1 files changed, 34 insertions, 0 deletions
diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp
index f0911d6a..25b1f5c9 100644
--- a/src/routines/level3/xgemm.hpp
+++ b/src/routines/level3/xgemm.hpp
@@ -29,6 +29,40 @@ class Xgemm: public Routine {
static const bool b_want_rotated_;
static const bool c_want_rotated_;
+ // Computes the size of the temporary GEMM buffer based on user-arguments
+ static size_t GetTempSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ const size_t mwg, const size_t nwg, const size_t kwg) {
+
+ // Computes the transpose/conjugate options and sets the a/b/c sizes based on that
+ bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate;
+ size_t a_one, a_two, b_one, b_two, c_one, c_two;
+ ProcessArguments(layout, a_transpose, b_transpose, m, n, k,
+ a_one, a_two, b_one, b_two, c_one, c_two,
+ a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate);
+
+ // Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account
+ // whether the matrices need to be rotated or not for the kernel.
+ size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i;
+ CalculateInternalDimensions(m, n, k, mwg, nwg, kwg,
+ a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i);
+
+ // Determines whether or not temporary matrices are needed
+ auto a_no_temp = NoTempBuffer(a_one, a_one_i, a_two, a_two_i, a_ld, a_offset, a_do_transpose, a_conjugate);
+ auto b_no_temp = NoTempBuffer(b_one, b_one_i, b_two, b_two_i, b_ld, b_offset, b_do_transpose, b_conjugate);
+ auto c_no_temp = NoTempBuffer(c_one, c_one_i, c_two, c_two_i, c_ld, c_offset, c_do_transpose, false);
+
+ // Computes the sizes and offsets for (optional) temporary buffers for the 3 matrices
+ auto b_temp_offset = size_t{0};
+ auto c_temp_offset = size_t{0};
+ return ComputeTempSize(a_no_temp, b_no_temp, c_no_temp,
+ a_one_i*a_two_i, b_one_i*b_two_i, c_one_i*c_two_i,
+ b_temp_offset, c_temp_offset);
+ }
+
// Selects which version of GEMM to run
static bool UseDirectKernel(const size_t m, const size_t n, const size_t k,
const size_t min_indirect_size) {