summaryrefslogtreecommitdiff
path: root/src/routines/common.hpp
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-01-08 21:07:01 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2018-01-08 21:07:01 +0100
commit99a4df88a6d808ea77c9116ce63621503c00b57a (patch)
treed8f8bc1b3884c0340df9f6d95b4837ed3dff8deb /src/routines/common.hpp
parent13f0f6fc6e612a5f77c6fd78b983f1b2bb8e36b6 (diff)
Implemented the in-direct version of the strided-batched GEMM kernel
Diffstat (limited to 'src/routines/common.hpp')
-rw-r--r--src/routines/common.hpp66
1 files changed, 66 insertions, 0 deletions
diff --git a/src/routines/common.hpp b/src/routines/common.hpp
index 06d001d9..6cbe1e1b 100644
--- a/src/routines/common.hpp
+++ b/src/routines/common.hpp
@@ -239,6 +239,72 @@ void PadCopyTransposeMatrixBatched(Queue &queue, const Device &device,
}
}
+// Batched version of the above
+template <typename T>
+void PadCopyTransposeMatrixStridedBatched(Queue &queue, const Device &device,
+ const Databases &db,
+ EventPointer event, const std::vector<Event> &waitForEvents,
+ const size_t src_one, const size_t src_two,
+ const size_t src_ld, const size_t src_offset,
+ const size_t src_stride, const Buffer<T> &src,
+ const size_t dest_one, const size_t dest_two,
+ const size_t dest_ld, const size_t dest_offset,
+ const size_t dest_stride, const Buffer<T> &dest,
+ const Program &program, const bool do_pad,
+ const bool do_transpose, const bool do_conjugate,
+ const size_t batch_count) {
+
+ // Determines the right kernel
+ auto kernel_name = std::string{};
+ if (do_transpose) {
+ kernel_name = (do_pad) ? "TransposePadMatrixStridedBatched" : "TransposeMatrixStridedBatched";
+ }
+ else {
+ kernel_name = (do_pad) ? "CopyPadMatrixStridedBatched" : "CopyMatrixStridedBatched";
+ }
+
+ // Retrieves the kernel from the compiled binary
+ auto kernel = Kernel(program, kernel_name);
+
+ // Sets the kernel arguments
+ kernel.SetArgument(0, static_cast<int>(src_one));
+ kernel.SetArgument(1, static_cast<int>(src_two));
+ kernel.SetArgument(2, static_cast<int>(src_ld));
+ kernel.SetArgument(3, static_cast<int>(src_offset));
+ kernel.SetArgument(4, static_cast<int>(src_stride));
+ kernel.SetArgument(5, src());
+ kernel.SetArgument(6, static_cast<int>(dest_one));
+ kernel.SetArgument(7, static_cast<int>(dest_two));
+ kernel.SetArgument(8, static_cast<int>(dest_ld));
+ kernel.SetArgument(9, static_cast<int>(dest_offset));
+ kernel.SetArgument(10, static_cast<int>(dest_stride));
+ kernel.SetArgument(11, dest());
+ if (do_pad) {
+ kernel.SetArgument(12, static_cast<int>(do_conjugate));
+ }
+
+ // Launches the kernel and returns the error code. Uses global and local thread sizes based on
+ // parameters in the database.
+ if (do_transpose) {
+ const auto global = std::vector<size_t>{
+ Ceil(CeilDiv(dest_one, db["PADTRA_WPT"]), db["PADTRA_TILE"]),
+ Ceil(CeilDiv(dest_two, db["PADTRA_WPT"]), db["PADTRA_TILE"]),
+ batch_count
+ };
+ const auto local = std::vector<size_t>{db["PADTRA_TILE"], db["PADTRA_TILE"], 1};
+ RunKernel(kernel, queue, device, global, local, event, waitForEvents);
+ }
+ else {
+ const auto global = std::vector<size_t>{
+ Ceil(CeilDiv(dest_one, db["PAD_WPTX"]), db["PAD_DIMX"]),
+ Ceil(CeilDiv(dest_two, db["PAD_WPTY"]), db["PAD_DIMY"]),
+ batch_count
+ };
+ const auto local = std::vector<size_t>{db["PAD_DIMX"], db["PAD_DIMY"], 1};
+ RunKernel(kernel, queue, device, global, local, event, waitForEvents);
+ }
+}
+
// =================================================================================================
} // namespace clblast