summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/internal/routines/level2/xgemv.h1
-rw-r--r--src/kernels/level2/xgemv.opencl6
-rw-r--r--src/kernels/level2/xgemv_fast.opencl14
-rw-r--r--src/routines/level2/xgemv.cc10
-rw-r--r--src/tuning/xgemv.cc8
5 files changed, 31 insertions, 8 deletions
diff --git a/include/internal/routines/level2/xgemv.h b/include/internal/routines/level2/xgemv.h
index 0b2a8e66..875f936e 100644
--- a/include/internal/routines/level2/xgemv.h
+++ b/include/internal/routines/level2/xgemv.h
@@ -29,6 +29,7 @@ class Xgemv: public Routine<T> {
using Routine<T>::source_string_;
using Routine<T>::queue_;
using Routine<T>::event_;
+ using Routine<T>::context_;
using Routine<T>::GetProgramFromCache;
using Routine<T>::TestVectorX;
using Routine<T>::TestVectorY;
diff --git a/src/kernels/level2/xgemv.opencl b/src/kernels/level2/xgemv.opencl
index 30b131b4..65b4291f 100644
--- a/src/kernels/level2/xgemv.opencl
+++ b/src/kernels/level2/xgemv.opencl
@@ -211,13 +211,17 @@ inline real LoadMatrixA(const __global real* restrict agm, const int x, const in
// Full version of the kernel
__attribute__((reqd_work_group_size(WGS1, 1, 1)))
-__kernel void Xgemv(const int m, const int n, const real alpha, const real beta,
+__kernel void Xgemv(const int m, const int n,
+ const __constant real* restrict arg_alpha,
+ const __constant real* restrict arg_beta,
const int a_rotated,
const __global real* restrict agm, const int a_offset, const int a_ld,
const __global real* restrict xgm, const int x_offset, const int x_inc,
__global real* ygm, const int y_offset, const int y_inc,
const int do_conjugate, const int parameter,
const int kl, const int ku) {
+ const real alpha = arg_alpha[0];
+ const real beta = arg_beta[0];
// Local memory for the vector X
__local real xlm[WGS1];
diff --git a/src/kernels/level2/xgemv_fast.opencl b/src/kernels/level2/xgemv_fast.opencl
index 61fdffa3..6a494e84 100644
--- a/src/kernels/level2/xgemv_fast.opencl
+++ b/src/kernels/level2/xgemv_fast.opencl
@@ -95,13 +95,18 @@ inline realVFR LoadMatrixAVFR(const __global realVFR* restrict agm, const int x,
// --> 'a_rotated' is 0
// --> 'do_conjugate' is 0
__attribute__((reqd_work_group_size(WGS2, 1, 1)))
-__kernel void XgemvFast(const int m, const int n, const real alpha, const real beta,
+__kernel void XgemvFast(const int m, const int n,
+ const __constant real* restrict arg_alpha,
+ const __constant real* restrict arg_beta,
const int a_rotated,
const __global realVF* restrict agm, const int a_offset, const int a_ld,
const __global real* restrict xgm, const int x_offset, const int x_inc,
__global real* ygm, const int y_offset, const int y_inc,
const int do_conjugate, const int parameter,
const int kl, const int ku) {
+ const real alpha = arg_alpha[0];
+ const real beta = arg_beta[0];
+
// Local memory for the vector X
__local real xlm[WGS2];
@@ -192,13 +197,18 @@ __kernel void XgemvFast(const int m, const int n, const real alpha, const real b
// --> 'a_rotated' is 1
// --> 'do_conjugate' is 0
__attribute__((reqd_work_group_size(WGS3, 1, 1)))
-__kernel void XgemvFastRot(const int m, const int n, const real alpha, const real beta,
+__kernel void XgemvFastRot(const int m, const int n,
+ const __constant real* restrict arg_alpha,
+ const __constant real* restrict arg_beta,
const int a_rotated,
const __global realVFR* restrict agm, const int a_offset, const int a_ld,
const __global real* restrict xgm, const int x_offset, const int x_inc,
__global real* ygm, const int y_offset, const int y_inc,
const int do_conjugate, const int parameter,
const int kl, const int ku) {
+ const real alpha = arg_alpha[0];
+ const real beta = arg_beta[0];
+
// Local memory for the vector X
__local real xlm[WGS3];
diff --git a/src/routines/level2/xgemv.cc b/src/routines/level2/xgemv.cc
index f8985038..4d6437a2 100644
--- a/src/routines/level2/xgemv.cc
+++ b/src/routines/level2/xgemv.cc
@@ -134,6 +134,12 @@ StatusCode Xgemv<T>::MatVec(const Layout layout, const Transpose a_transpose,
local_size = db_["WGS3"];
}
+ // Upload the scalar arguments as constant buffers to the device (needed for half-precision)
+ auto alpha_buffer = Buffer<T>(context_, 1);
+ auto beta_buffer = Buffer<T>(context_, 1);
+ alpha_buffer.Write(queue_, 1, &alpha);
+ beta_buffer.Write(queue_, 1, &beta);
+
// Retrieves the Xgemv kernel from the compiled binary
try {
const auto program = GetProgramFromCache();
@@ -142,8 +148,8 @@ StatusCode Xgemv<T>::MatVec(const Layout layout, const Transpose a_transpose,
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(m_real));
kernel.SetArgument(1, static_cast<int>(n_real));
- kernel.SetArgument(2, alpha);
- kernel.SetArgument(3, beta);
+ kernel.SetArgument(2, alpha_buffer());
+ kernel.SetArgument(3, beta_buffer());
kernel.SetArgument(4, static_cast<int>(a_rotated));
kernel.SetArgument(5, a_buffer());
kernel.SetArgument(6, static_cast<int>(a_offset));
diff --git a/src/tuning/xgemv.cc b/src/tuning/xgemv.cc
index 43369c3b..6587dcf4 100644
--- a/src/tuning/xgemv.cc
+++ b/src/tuning/xgemv.cc
@@ -96,11 +96,13 @@ class TuneXgemv {
std::vector<T> &x_vec, std::vector<T> &y_vec,
std::vector<T> &a_mat, std::vector<T> &, std::vector<T> &,
std::vector<T> &) {
+ auto alpha_buffer = std::vector<T>{args.alpha};
+ auto beta_buffer = std::vector<T>{args.beta};
auto a_rotated = (V==3) ? 1 : 0;
tuner.AddArgumentScalar(static_cast<int>(args.m));
tuner.AddArgumentScalar(static_cast<int>(args.n));
- tuner.AddArgumentScalar(args.alpha);
- tuner.AddArgumentScalar(args.beta);
+ tuner.AddArgumentInput(alpha_buffer);
+ tuner.AddArgumentInput(beta_buffer);
tuner.AddArgumentScalar(static_cast<int>(a_rotated));
tuner.AddArgumentInput(a_mat);
tuner.AddArgumentScalar(0);
@@ -135,7 +137,7 @@ using double2 = clblast::double2;
template <int V>
void StartVariation(int argc, char *argv[]) {
switch(clblast::GetPrecision(argc, argv)) {
- case clblast::Precision::kHalf: throw std::runtime_error("Unsupported precision mode");
+ case clblast::Precision::kHalf: clblast::Tuner<clblast::TuneXgemv<half,V>, half>(argc, argv); break;
case clblast::Precision::kSingle: clblast::Tuner<clblast::TuneXgemv<float,V>, float>(argc, argv); break;
case clblast::Precision::kDouble: clblast::Tuner<clblast::TuneXgemv<double,V>, double>(argc, argv); break;
case clblast::Precision::kComplexSingle: clblast::Tuner<clblast::TuneXgemv<float2,V>, float2>(argc, argv); break;