diff options
-rw-r--r-- | include/internal/routines/level2/xgemv.h | 1 | ||||
-rw-r--r-- | src/kernels/level2/xgemv.opencl | 6 | ||||
-rw-r--r-- | src/kernels/level2/xgemv_fast.opencl | 14 | ||||
-rw-r--r-- | src/routines/level2/xgemv.cc | 10 | ||||
-rw-r--r-- | src/tuning/xgemv.cc | 8 |
5 files changed, 31 insertions, 8 deletions
diff --git a/include/internal/routines/level2/xgemv.h b/include/internal/routines/level2/xgemv.h index 0b2a8e66..875f936e 100644 --- a/include/internal/routines/level2/xgemv.h +++ b/include/internal/routines/level2/xgemv.h @@ -29,6 +29,7 @@ class Xgemv: public Routine<T> { using Routine<T>::source_string_; using Routine<T>::queue_; using Routine<T>::event_; + using Routine<T>::context_; using Routine<T>::GetProgramFromCache; using Routine<T>::TestVectorX; using Routine<T>::TestVectorY; diff --git a/src/kernels/level2/xgemv.opencl b/src/kernels/level2/xgemv.opencl index 30b131b4..65b4291f 100644 --- a/src/kernels/level2/xgemv.opencl +++ b/src/kernels/level2/xgemv.opencl @@ -211,13 +211,17 @@ inline real LoadMatrixA(const __global real* restrict agm, const int x, const in // Full version of the kernel __attribute__((reqd_work_group_size(WGS1, 1, 1))) -__kernel void Xgemv(const int m, const int n, const real alpha, const real beta, +__kernel void Xgemv(const int m, const int n, + const __constant real* restrict arg_alpha, + const __constant real* restrict arg_beta, const int a_rotated, const __global real* restrict agm, const int a_offset, const int a_ld, const __global real* restrict xgm, const int x_offset, const int x_inc, __global real* ygm, const int y_offset, const int y_inc, const int do_conjugate, const int parameter, const int kl, const int ku) { + const real alpha = arg_alpha[0]; + const real beta = arg_beta[0]; // Local memory for the vector X __local real xlm[WGS1]; diff --git a/src/kernels/level2/xgemv_fast.opencl b/src/kernels/level2/xgemv_fast.opencl index 61fdffa3..6a494e84 100644 --- a/src/kernels/level2/xgemv_fast.opencl +++ b/src/kernels/level2/xgemv_fast.opencl @@ -95,13 +95,18 @@ inline realVFR LoadMatrixAVFR(const __global realVFR* restrict agm, const int x, // --> 'a_rotated' is 0 // --> 'do_conjugate' is 0 __attribute__((reqd_work_group_size(WGS2, 1, 1))) -__kernel void XgemvFast(const int m, const int n, const real alpha, const real beta, +__kernel void XgemvFast(const int m, const int n, + const __constant real* restrict arg_alpha, + const __constant real* restrict arg_beta, const int a_rotated, const __global realVF* restrict agm, const int a_offset, const int a_ld, const __global real* restrict xgm, const int x_offset, const int x_inc, __global real* ygm, const int y_offset, const int y_inc, const int do_conjugate, const int parameter, const int kl, const int ku) { + const real alpha = arg_alpha[0]; + const real beta = arg_beta[0]; + // Local memory for the vector X __local real xlm[WGS2]; @@ -192,13 +197,18 @@ __kernel void XgemvFast(const int m, const int n, const real alpha, const real b // --> 'a_rotated' is 1 // --> 'do_conjugate' is 0 __attribute__((reqd_work_group_size(WGS3, 1, 1))) -__kernel void XgemvFastRot(const int m, const int n, const real alpha, const real beta, +__kernel void XgemvFastRot(const int m, const int n, + const __constant real* restrict arg_alpha, + const __constant real* restrict arg_beta, const int a_rotated, const __global realVFR* restrict agm, const int a_offset, const int a_ld, const __global real* restrict xgm, const int x_offset, const int x_inc, __global real* ygm, const int y_offset, const int y_inc, const int do_conjugate, const int parameter, const int kl, const int ku) { + const real alpha = arg_alpha[0]; + const real beta = arg_beta[0]; + // Local memory for the vector X __local real xlm[WGS3]; diff --git a/src/routines/level2/xgemv.cc b/src/routines/level2/xgemv.cc index f8985038..4d6437a2 100644 --- a/src/routines/level2/xgemv.cc +++ b/src/routines/level2/xgemv.cc @@ -134,6 +134,12 @@ StatusCode Xgemv<T>::MatVec(const Layout layout, const Transpose a_transpose, local_size = db_["WGS3"]; } + // Upload the scalar arguments as constant buffers to the device (needed for half-precision) + auto alpha_buffer = Buffer<T>(context_, 1); + auto beta_buffer = Buffer<T>(context_, 1); + alpha_buffer.Write(queue_, 1, &alpha); + beta_buffer.Write(queue_, 1, &beta); + // Retrieves the Xgemv kernel from the compiled binary try { const auto program = GetProgramFromCache(); @@ -142,8 +148,8 @@ StatusCode Xgemv<T>::MatVec(const Layout layout, const Transpose a_transpose, // Sets the kernel arguments kernel.SetArgument(0, static_cast<int>(m_real)); kernel.SetArgument(1, static_cast<int>(n_real)); - kernel.SetArgument(2, alpha); - kernel.SetArgument(3, beta); + kernel.SetArgument(2, alpha_buffer()); + kernel.SetArgument(3, beta_buffer()); kernel.SetArgument(4, static_cast<int>(a_rotated)); kernel.SetArgument(5, a_buffer()); kernel.SetArgument(6, static_cast<int>(a_offset)); diff --git a/src/tuning/xgemv.cc b/src/tuning/xgemv.cc index 43369c3b..6587dcf4 100644 --- a/src/tuning/xgemv.cc +++ b/src/tuning/xgemv.cc @@ -96,11 +96,13 @@ class TuneXgemv { std::vector<T> &x_vec, std::vector<T> &y_vec, std::vector<T> &a_mat, std::vector<T> &, std::vector<T> &, std::vector<T> &) { + auto alpha_buffer = std::vector<T>{args.alpha}; + auto beta_buffer = std::vector<T>{args.beta}; auto a_rotated = (V==3) ? 1 : 0; tuner.AddArgumentScalar(static_cast<int>(args.m)); tuner.AddArgumentScalar(static_cast<int>(args.n)); - tuner.AddArgumentScalar(args.alpha); - tuner.AddArgumentScalar(args.beta); + tuner.AddArgumentInput(alpha_buffer); + tuner.AddArgumentInput(beta_buffer); tuner.AddArgumentScalar(static_cast<int>(a_rotated)); tuner.AddArgumentInput(a_mat); tuner.AddArgumentScalar(0); @@ -135,7 +137,7 @@ using double2 = clblast::double2; template <int V> void StartVariation(int argc, char *argv[]) { switch(clblast::GetPrecision(argc, argv)) { - case clblast::Precision::kHalf: throw std::runtime_error("Unsupported precision mode"); + case clblast::Precision::kHalf: clblast::Tuner<clblast::TuneXgemv<half,V>, half>(argc, argv); break; case clblast::Precision::kSingle: clblast::Tuner<clblast::TuneXgemv<float,V>, float>(argc, argv); break; case clblast::Precision::kDouble: clblast::Tuner<clblast::TuneXgemv<double,V>, double>(argc, argv); break; case clblast::Precision::kComplexSingle: clblast::Tuner<clblast::TuneXgemv<float2,V>, float2>(argc, argv); break; |