summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/internal/routines/level1/xaxpy.h1
-rw-r--r--src/kernels/common.opencl7
-rw-r--r--src/kernels/level1/xaxpy.opencl9
-rw-r--r--src/routines/level1/xaxpy.cc8
4 files changed, 12 insertions, 13 deletions
diff --git a/include/internal/routines/level1/xaxpy.h b/include/internal/routines/level1/xaxpy.h
index bc00c8e3..03771d53 100644
--- a/include/internal/routines/level1/xaxpy.h
+++ b/include/internal/routines/level1/xaxpy.h
@@ -29,6 +29,7 @@ class Xaxpy: public Routine<T> {
using Routine<T>::source_string_;
using Routine<T>::queue_;
using Routine<T>::event_;
+ using Routine<T>::context_;
using Routine<T>::GetProgramFromCache;
using Routine<T>::TestVectorX;
using Routine<T>::TestVectorY;
diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl
index df9ec35b..f0da5a47 100644
--- a/src/kernels/common.opencl
+++ b/src/kernels/common.opencl
@@ -100,13 +100,6 @@ R"(
#define SMALLEST -1.0e37
#endif
-// Kernel argument scalar
-#if PRECISION == 16
- typedef float realarg;
-#else
- typedef real realarg;
-#endif
-
// Single-element version of a complex number
#if PRECISION == 3232
typedef float singlereal;
diff --git a/src/kernels/level1/xaxpy.opencl b/src/kernels/level1/xaxpy.opencl
index 58b7a196..e0efadc1 100644
--- a/src/kernels/level1/xaxpy.opencl
+++ b/src/kernels/level1/xaxpy.opencl
@@ -23,10 +23,10 @@ R"(
// Full version of the kernel with offsets and strided accesses
__attribute__((reqd_work_group_size(WGS, 1, 1)))
-__kernel void Xaxpy(const int n, const realarg arg_alpha,
+__kernel void Xaxpy(const int n, const __constant real* restrict arg_alpha,
const __global real* restrict xgm, const int x_offset, const int x_inc,
__global real* ygm, const int y_offset, const int y_inc) {
- const real alpha = (real)arg_alpha;
+ const real alpha = arg_alpha[0];
// Loops over the work that needs to be done (allows for an arbitrary number of threads)
#pragma unroll
@@ -41,10 +41,11 @@ __kernel void Xaxpy(const int n, const realarg arg_alpha,
// Faster version of the kernel without offsets and strided accesses. Also assumes that 'n' is
// dividable by 'VW', 'WGS' and 'WPT'.
__attribute__((reqd_work_group_size(WGS, 1, 1)))
-__kernel void XaxpyFast(const int n, const realarg arg_alpha,
+__kernel void XaxpyFast(const int n, const __constant real* restrict arg_alpha,
const __global realV* restrict xgm,
__global realV* ygm) {
- const real alpha = (real)arg_alpha;
+ const real alpha = arg_alpha[0];
+
#pragma unroll
for (int w=0; w<WPT; ++w) {
const int id = w*get_global_size(0) + get_global_id(0);
diff --git a/src/routines/level1/xaxpy.cc b/src/routines/level1/xaxpy.cc
index b7956bf2..66aa2336 100644
--- a/src/routines/level1/xaxpy.cc
+++ b/src/routines/level1/xaxpy.cc
@@ -68,16 +68,20 @@ StatusCode Xaxpy<T>::DoAxpy(const size_t n, const T alpha,
const auto program = GetProgramFromCache();
auto kernel = Kernel(program, kernel_name);
+ // Upload the scalar argument as a constant buffer to the device (needed for half-precision)
+ auto alpha_buffer = Buffer<T>(context_, 1);
+ alpha_buffer.Write(queue_, 1, &alpha);
+
// Sets the kernel arguments
if (use_fast_kernel) {
kernel.SetArgument(0, static_cast<int>(n));
- kernel.SetArgument(1, static_cast<typename RealArg<T>::Type>(alpha));
+ kernel.SetArgument(1, alpha_buffer());
kernel.SetArgument(2, x_buffer());
kernel.SetArgument(3, y_buffer());
}
else {
kernel.SetArgument(0, static_cast<int>(n));
- kernel.SetArgument(1, static_cast<typename RealArg<T>::Type>(alpha));
+ kernel.SetArgument(1, alpha_buffer());
kernel.SetArgument(2, x_buffer());
kernel.SetArgument(3, static_cast<int>(x_offset));
kernel.SetArgument(4, static_cast<int>(x_inc));