diff options
Diffstat (limited to 'src/routines/level2/xgemv.cc')
-rw-r--r-- | src/routines/level2/xgemv.cc | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/src/routines/level2/xgemv.cc b/src/routines/level2/xgemv.cc index f8985038..4d6437a2 100644 --- a/src/routines/level2/xgemv.cc +++ b/src/routines/level2/xgemv.cc @@ -134,6 +134,12 @@ StatusCode Xgemv<T>::MatVec(const Layout layout, const Transpose a_transpose, local_size = db_["WGS3"]; } + // Upload the scalar arguments as constant buffers to the device (needed for half-precision) + auto alpha_buffer = Buffer<T>(context_, 1); + auto beta_buffer = Buffer<T>(context_, 1); + alpha_buffer.Write(queue_, 1, &alpha); + beta_buffer.Write(queue_, 1, &beta); + // Retrieves the Xgemv kernel from the compiled binary try { const auto program = GetProgramFromCache(); @@ -142,8 +148,8 @@ StatusCode Xgemv<T>::MatVec(const Layout layout, const Transpose a_transpose, // Sets the kernel arguments kernel.SetArgument(0, static_cast<int>(m_real)); kernel.SetArgument(1, static_cast<int>(n_real)); - kernel.SetArgument(2, alpha); - kernel.SetArgument(3, beta); + kernel.SetArgument(2, alpha_buffer()); + kernel.SetArgument(3, beta_buffer()); kernel.SetArgument(4, static_cast<int>(a_rotated)); kernel.SetArgument(5, a_buffer()); kernel.SetArgument(6, static_cast<int>(a_offset)); |