diff options
-rw-r--r-- | include/internal/routines/level1/xaxpy.h | 1 | ||||
-rw-r--r-- | src/kernels/common.opencl | 7 | ||||
-rw-r--r-- | src/kernels/level1/xaxpy.opencl | 9 | ||||
-rw-r--r-- | src/routines/level1/xaxpy.cc | 8 |
4 files changed, 12 insertions, 13 deletions
diff --git a/include/internal/routines/level1/xaxpy.h b/include/internal/routines/level1/xaxpy.h index bc00c8e3..03771d53 100644 --- a/include/internal/routines/level1/xaxpy.h +++ b/include/internal/routines/level1/xaxpy.h @@ -29,6 +29,7 @@ class Xaxpy: public Routine<T> { using Routine<T>::source_string_; using Routine<T>::queue_; using Routine<T>::event_; + using Routine<T>::context_; using Routine<T>::GetProgramFromCache; using Routine<T>::TestVectorX; using Routine<T>::TestVectorY; diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl index df9ec35b..f0da5a47 100644 --- a/src/kernels/common.opencl +++ b/src/kernels/common.opencl @@ -100,13 +100,6 @@ R"( #define SMALLEST -1.0e37 #endif -// Kernel argument scalar -#if PRECISION == 16 - typedef float realarg; -#else - typedef real realarg; -#endif - // Single-element version of a complex number #if PRECISION == 3232 typedef float singlereal; diff --git a/src/kernels/level1/xaxpy.opencl b/src/kernels/level1/xaxpy.opencl index 58b7a196..e0efadc1 100644 --- a/src/kernels/level1/xaxpy.opencl +++ b/src/kernels/level1/xaxpy.opencl @@ -23,10 +23,10 @@ R"( // Full version of the kernel with offsets and strided accesses __attribute__((reqd_work_group_size(WGS, 1, 1))) -__kernel void Xaxpy(const int n, const realarg arg_alpha, +__kernel void Xaxpy(const int n, const __constant real* restrict arg_alpha, const __global real* restrict xgm, const int x_offset, const int x_inc, __global real* ygm, const int y_offset, const int y_inc) { - const real alpha = (real)arg_alpha; + const real alpha = arg_alpha[0]; // Loops over the work that needs to be done (allows for an arbitrary number of threads) #pragma unroll @@ -41,10 +41,11 @@ __kernel void Xaxpy(const int n, const realarg arg_alpha, // Faster version of the kernel without offsets and strided accesses. Also assumes that 'n' is // dividable by 'VW', 'WGS' and 'WPT'. __attribute__((reqd_work_group_size(WGS, 1, 1))) -__kernel void XaxpyFast(const int n, const realarg arg_alpha, +__kernel void XaxpyFast(const int n, const __constant real* restrict arg_alpha, const __global realV* restrict xgm, __global realV* ygm) { - const real alpha = (real)arg_alpha; + const real alpha = arg_alpha[0]; + #pragma unroll for (int w=0; w<WPT; ++w) { const int id = w*get_global_size(0) + get_global_id(0); diff --git a/src/routines/level1/xaxpy.cc b/src/routines/level1/xaxpy.cc index b7956bf2..66aa2336 100644 --- a/src/routines/level1/xaxpy.cc +++ b/src/routines/level1/xaxpy.cc @@ -68,16 +68,20 @@ StatusCode Xaxpy<T>::DoAxpy(const size_t n, const T alpha, const auto program = GetProgramFromCache(); auto kernel = Kernel(program, kernel_name); + // Upload the scalar argument as a constant buffer to the device (needed for half-precision) + auto alpha_buffer = Buffer<T>(context_, 1); + alpha_buffer.Write(queue_, 1, &alpha); + // Sets the kernel arguments if (use_fast_kernel) { kernel.SetArgument(0, static_cast<int>(n)); - kernel.SetArgument(1, static_cast<typename RealArg<T>::Type>(alpha)); + kernel.SetArgument(1, alpha_buffer()); kernel.SetArgument(2, x_buffer()); kernel.SetArgument(3, y_buffer()); } else { kernel.SetArgument(0, static_cast<int>(n)); - kernel.SetArgument(1, static_cast<typename RealArg<T>::Type>(alpha)); + kernel.SetArgument(1, alpha_buffer()); kernel.SetArgument(2, x_buffer()); kernel.SetArgument(3, static_cast<int>(x_offset)); kernel.SetArgument(4, static_cast<int>(x_inc)); |