From 69ed46c8da69ee18338eca5102ead43410cc01b5 Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Fri, 2 Feb 2018 21:18:37 +0100 Subject: Implemented the XHAD Hadamard product routine --- src/routines/levelx/xhad.cpp | 58 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) (limited to 'src/routines') diff --git a/src/routines/levelx/xhad.cpp b/src/routines/levelx/xhad.cpp index 46ae8031..da416cc7 100644 --- a/src/routines/levelx/xhad.cpp +++ b/src/routines/levelx/xhad.cpp @@ -24,7 +24,7 @@ template Xhad::Xhad(Queue &queue, EventPointer event, const std::string &name): Routine(queue, event, name, {"Xaxpy"}, PrecisionValue(), {}, { #include "../../kernels/level1/level1.opencl" -#include "../../kernels/level1/xaxpy.opencl" +#include "../../kernels/level1/xhad.opencl" }) { } @@ -45,6 +45,62 @@ void Xhad::DoHad(const size_t n, const T alpha, TestVectorY(n, y_buffer, y_offset, y_inc); TestVectorY(n, z_buffer, z_offset, z_inc); // TODO: Make a TestVectorZ function with error codes + // Determines whether or not the fast-version can be used + const auto use_faster_kernel = (x_offset == 0) && (x_inc == 1) && + (y_offset == 0) && (y_inc == 1) && + (z_offset == 0) && (z_inc == 1) && + IsMultiple(n, db_["WPT"]*db_["VW"]); + const auto use_fastest_kernel = use_faster_kernel && + IsMultiple(n, db_["WGS"]*db_["WPT"]*db_["VW"]); + + // If possible, run the fast-version of the kernel + const auto kernel_name = (use_fastest_kernel) ? "XhadFastest" : + (use_faster_kernel) ? "XhadFaster" : "Xhad"; + + // Retrieves the Xhad kernel from the compiled binary + auto kernel = Kernel(program_, kernel_name); + + // Sets the kernel arguments + if (use_faster_kernel || use_fastest_kernel) { + kernel.SetArgument(0, static_cast(n)); + kernel.SetArgument(1, GetRealArg(alpha)); + kernel.SetArgument(2, GetRealArg(beta)); + kernel.SetArgument(3, x_buffer()); + kernel.SetArgument(4, y_buffer()); + kernel.SetArgument(5, z_buffer()); + } + else { + kernel.SetArgument(0, static_cast(n)); + kernel.SetArgument(1, GetRealArg(alpha)); + kernel.SetArgument(2, GetRealArg(beta)); + kernel.SetArgument(3, x_buffer()); + kernel.SetArgument(4, static_cast(x_offset)); + kernel.SetArgument(5, static_cast(x_inc)); + kernel.SetArgument(6, y_buffer()); + kernel.SetArgument(7, static_cast(y_offset)); + kernel.SetArgument(8, static_cast(y_inc)); + kernel.SetArgument(9, z_buffer()); + kernel.SetArgument(10, static_cast(z_offset)); + kernel.SetArgument(11, static_cast(z_inc)); + } + + // Launches the kernel + if (use_fastest_kernel) { + auto global = std::vector{CeilDiv(n, db_["WPT"]*db_["VW"])}; + auto local = std::vector{db_["WGS"]}; + RunKernel(kernel, queue_, device_, global, local, event_); + } + else if (use_faster_kernel) { + auto global = std::vector{Ceil(CeilDiv(n, db_["WPT"]*db_["VW"]), db_["WGS"])}; + auto local = std::vector{db_["WGS"]}; + RunKernel(kernel, queue_, device_, global, local, event_); + } + else { + const auto n_ceiled = Ceil(n, db_["WGS"]*db_["WPT"]); + auto global = std::vector{n_ceiled/db_["WPT"]}; + auto local = std::vector{db_["WGS"]}; + RunKernel(kernel, queue_, device_, global, local, event_); + } } // ================================================================================================= -- cgit v1.2.3