diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/correctness/routines/level1/xnrm2.cc | 2 | ||||
-rw-r--r-- | test/performance/routines/level1/xnrm2.cc | 6 | ||||
-rw-r--r-- | test/wrapper_clblas.h | 32 |
3 files changed, 37 insertions, 3 deletions
diff --git a/test/correctness/routines/level1/xnrm2.cc b/test/correctness/routines/level1/xnrm2.cc
index 8238e868..97fb0ad6 100644
--- a/test/correctness/routines/level1/xnrm2.cc
+++ b/test/correctness/routines/level1/xnrm2.cc
@@ -20,6 +20,8 @@ using double2 = clblast::double2;
 int main(int argc, char *argv[]) {
   clblast::RunTests<clblast::TestXnrm2<float>, float, float>(argc, argv, false, "SNRM2");
   clblast::RunTests<clblast::TestXnrm2<double>, double, double>(argc, argv, true, "DNRM2");
+  clblast::RunTests<clblast::TestXnrm2<float2>, float2, float2>(argc, argv, true, "ScNRM2");
+  clblast::RunTests<clblast::TestXnrm2<double2>, double2, double2>(argc, argv, true, "DzNRM2");
   return 0;
 }

diff --git a/test/performance/routines/level1/xnrm2.cc b/test/performance/routines/level1/xnrm2.cc
index d5ae348b..db6ec9ad 100644
--- a/test/performance/routines/level1/xnrm2.cc
+++ b/test/performance/routines/level1/xnrm2.cc
@@ -24,8 +24,10 @@ int main(int argc, char *argv[]) {
       clblast::RunClient<clblast::TestXnrm2<float>, float, float>(argc, argv); break;
     case clblast::Precision::kDouble:
       clblast::RunClient<clblast::TestXnrm2<double>, double, double>(argc, argv); break;
-    case clblast::Precision::kComplexSingle: throw std::runtime_error("Unsupported precision mode");
-    case clblast::Precision::kComplexDouble: throw std::runtime_error("Unsupported precision mode");
+    case clblast::Precision::kComplexSingle:
+      clblast::RunClient<clblast::TestXnrm2<float2>, float2, float2>(argc, argv); break;
+    case clblast::Precision::kComplexDouble:
+      clblast::RunClient<clblast::TestXnrm2<double2>, double2, double2>(argc, argv); break;
   }
   return 0;
 }
diff --git a/test/wrapper_clblas.h b/test/wrapper_clblas.h
index 501f0bc5..37d9eee5 100644
--- a/test/wrapper_clblas.h
+++ b/test/wrapper_clblas.h
@@ -350,7 +350,7 @@ clblasStatus clblasXdotc<double2>(const size_t n,
                      num_queues, queues, num_wait_events, wait_events, events);
 }

-// Forwards the clBLAS calls for SNRM2/DNRM2
+// Forwards the clBLAS calls for SNRM2/DNRM2/ScNRM2/DzNRM2
 template <typename T>
 clblasStatus clblasXnrm2(const size_t n,
                          cl_mem nrm2_buffer, const size_t nrm2_offset,
@@ -387,6 +387,36 @@ clblasStatus clblasXnrm2<double>(const size_t n,
                      scratch_buffer(),
                      num_queues, queues, num_wait_events, wait_events, events);
 }
+template <>
+clblasStatus clblasXnrm2<float2>(const size_t n,
+                                 cl_mem nrm2_buffer, const size_t nrm2_offset,
+                                 const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
+                                 cl_uint num_queues, cl_command_queue *queues,
+                                 cl_uint num_wait_events, const cl_event *wait_events, cl_event *events) {
+  auto queue = Queue(queues[0]);
+  auto context = queue.GetContext();
+  auto scratch_buffer = Buffer<float2>(context, n*x_inc + x_offset);
+  return clblasScnrm2(n,
+                      nrm2_buffer, nrm2_offset,
+                      x_buffer, x_offset, static_cast<int>(x_inc),
+                      scratch_buffer(),
+                      num_queues, queues, num_wait_events, wait_events, events);
+}
+template <>
+clblasStatus clblasXnrm2<double2>(const size_t n,
+                                  cl_mem nrm2_buffer, const size_t nrm2_offset,
+                                  const cl_mem x_buffer, const size_t x_offset, const size_t x_inc,
+                                  cl_uint num_queues, cl_command_queue *queues,
+                                  cl_uint num_wait_events, const cl_event *wait_events, cl_event *events) {
+  auto queue = Queue(queues[0]);
+  auto context = queue.GetContext();
+  auto scratch_buffer = Buffer<double2>(context, n*x_inc + x_offset);
+  return clblasDznrm2(n,
+                      nrm2_buffer, nrm2_offset,
+                      x_buffer, x_offset, static_cast<int>(x_inc),
+                      scratch_buffer(),
+                      num_queues, queues, num_wait_events, wait_events, events);
+}

 // =================================================================================================
 // BLAS level-2 (matrix-vector) routines