diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-03-19 15:57:44 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-03-19 15:57:44 +0100 |
commit | 2fd04dae83acb01933856e768a938db9ac808ce0 (patch) | |
tree | c8beec574d271f686f12211b993a47e462e55298 /src/routines/common.hpp | |
parent | 11bb30e72bf1f2f36380c0bae8593d2e27ce3bfe (diff) |
Added batched versions of the pad/copy/transpose kernels
Diffstat (limited to 'src/routines/common.hpp')
-rw-r--r-- | src/routines/common.hpp | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/src/routines/common.hpp b/src/routines/common.hpp index be6ac4ec..28a43da5 100644 --- a/src/routines/common.hpp +++ b/src/routines/common.hpp @@ -196,6 +196,70 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device, } } +// Batched version of the above +template <typename T> +void PadCopyTransposeMatrixBatched(Queue &queue, const Device &device, + const Databases &db, + EventPointer event, const std::vector<Event> &waitForEvents, + const size_t src_one, const size_t src_two, + const size_t src_ld, const Buffer<int> &src_offsets, + const Buffer<T> &src, + const size_t dest_one, const size_t dest_two, + const size_t dest_ld, const Buffer<int> &dest_offsets, + const Buffer<T> &dest, + const Program &program, const bool do_pad, + const bool do_transpose, const bool do_conjugate, + const size_t batch_count) { + + // Determines the right kernel + auto kernel_name = std::string{}; + if (do_transpose) { + kernel_name = (do_pad) ? "TransposePadMatrixBatched" : "TransposeMatrixBatched"; + } + else { + kernel_name = (do_pad) ? "CopyPadMatrixBatched" : "CopyMatrixBatched"; + } + + // Retrieves the kernel from the compiled binary + auto kernel = Kernel(program, kernel_name); + + // Sets the kernel arguments + kernel.SetArgument(0, static_cast<int>(src_one)); + kernel.SetArgument(1, static_cast<int>(src_two)); + kernel.SetArgument(2, static_cast<int>(src_ld)); + kernel.SetArgument(3, src_offsets()); + kernel.SetArgument(4, src()); + kernel.SetArgument(5, static_cast<int>(dest_one)); + kernel.SetArgument(6, static_cast<int>(dest_two)); + kernel.SetArgument(7, static_cast<int>(dest_ld)); + kernel.SetArgument(8, dest_offsets()); + kernel.SetArgument(9, dest()); + if (do_pad) { + kernel.SetArgument(10, static_cast<int>(do_conjugate)); + } + + // Launches the kernel and returns the error code. Uses global and local thread sizes based on + // parameters in the database. + if (do_transpose) { + const auto global = std::vector<size_t>{ + Ceil(CeilDiv(dest_one, db["PADTRA_WPT"]), db["PADTRA_TILE"]), + Ceil(CeilDiv(dest_two, db["PADTRA_WPT"]), db["PADTRA_TILE"]), + batch_count + }; + const auto local = std::vector<size_t>{db["PADTRA_TILE"], db["PADTRA_TILE"], 1}; + RunKernel(kernel, queue, device, global, local, event, waitForEvents); + } + else { + const auto global = std::vector<size_t>{ + Ceil(CeilDiv(dest_one, db["PAD_WPTX"]), db["PAD_DIMX"]), + Ceil(CeilDiv(dest_two, db["PAD_WPTY"]), db["PAD_DIMY"]), + batch_count + }; + const auto local = std::vector<size_t>{db["PAD_DIMX"], db["PAD_DIMY"], 1}; + RunKernel(kernel, queue, device, global, local, event, waitForEvents); + } +} + // ================================================================================================= } // namespace clblast |