From 49e04c7fce8fed45559e143137cef3a1a36328cc Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Fri, 10 Mar 2017 21:24:35 +0100 Subject: Added API and test infrastructure for the batched GEMM routine --- src/routines/levelx/xgemmbatched.cpp | 115 +++++++++++++++++++++++++++++++++++ src/routines/levelx/xgemmbatched.hpp | 47 ++++++++++++++ 2 files changed, 162 insertions(+) create mode 100644 src/routines/levelx/xgemmbatched.cpp create mode 100644 src/routines/levelx/xgemmbatched.hpp (limited to 'src/routines/levelx') diff --git a/src/routines/levelx/xgemmbatched.cpp b/src/routines/levelx/xgemmbatched.cpp new file mode 100644 index 00000000..b07425d5 --- /dev/null +++ b/src/routines/levelx/xgemmbatched.cpp @@ -0,0 +1,115 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren +// +// This file implements the XgemmBatched class (see the header for information about the class). 
+//
+// =================================================================================================
+
+#include "routines/levelx/xgemmbatched.hpp"
+
+#include <string>
+#include <vector>
+
+namespace clblast {
+// =================================================================================================
+
+// Constructor: forwards to base class constructor. Passes on the queue, event and routine name,
+// the list of kernel/database names, the precision, and the OpenCL kernel sources.
+// NOTE(review): the contents of every <...> in this file (template parameter lists, standard
+// header names, container element types) were missing from the patch text and have been
+// reconstructed from usage -- confirm against the repository.
+template <typename T>
+XgemmBatched<T>::XgemmBatched(Queue &queue, EventPointer event, const std::string &name):
+    Routine(queue, event, name,
+            {"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","KernelSelection"},
+            PrecisionValue<T>(), {}, {
+    #include "../../kernels/level3/level3.opencl"
+    #include "../../kernels/level3/copy_fast.opencl"
+    #include "../../kernels/level3/copy_pad.opencl"
+    #include "../../kernels/level3/transpose_fast.opencl"
+    #include "../../kernels/level3/transpose_pad.opencl"
+    #include "../../kernels/level3/convert_symmetric.opencl"
+    #include "../../kernels/level3/convert_triangular.opencl"
+    #include "../../kernels/level3/convert_hermitian.opencl"
+    , // separated in multiple parts to prevent C1091 in MSVC 2013
+    #include "../../kernels/level3/xgemm_direct_part1.opencl"
+    #include "../../kernels/level3/xgemm_direct_part2.opencl"
+    #include "../../kernels/level3/xgemm_direct_part3.opencl"
+    , // separated in multiple parts to prevent C1091 in MSVC 2013
+    #include "../../kernels/level3/xgemm_part1.opencl"
+    #include "../../kernels/level3/xgemm_part2.opencl"
+    #include "../../kernels/level3/xgemm_part3.opencl"
+    }) {
+}
+
+// =================================================================================================
+
+// The main routine. At this stage it only validates its arguments: the batch description, the
+// problem dimensions, and the A, B and C matrices of every batch entry. No kernel is launched yet.
+//
+// Throws BLASError(StatusCode::kInvalidBatchCount) when batch_count is zero or when alphas, betas,
+// a_offsets, b_offsets or c_offsets does not hold exactly batch_count entries. Throws
+// BLASError(StatusCode::kInvalidDimension) when m, n or k is zero. The TestMatrixA/B/C helpers are
+// defined elsewhere; presumably they throw on an invalid buffer, offset or leading dimension.
+template <typename T>
+void XgemmBatched<T>::DoGemmBatched(const Layout layout,
+                                    const Transpose a_transpose, const Transpose b_transpose,
+                                    const size_t m, const size_t n, const size_t k,
+                                    const std::vector<T> &alphas,
+                                    const Buffer<T> &a_buffer,
+                                    const std::vector<size_t> &a_offsets, const size_t a_ld,
+                                    const Buffer<T> &b_buffer,
+                                    const std::vector<size_t> &b_offsets, const size_t b_ld,
+                                    const std::vector<T> &betas,
+                                    const Buffer<T> &c_buffer,
+                                    const std::vector<size_t> &c_offsets, const size_t c_ld,
+                                    const size_t batch_count) {
+
+  // Tests for a valid batch count: every per-batch vector needs exactly one entry per batch
+  if ((batch_count < 1) || (alphas.size() != batch_count) || (betas.size() != batch_count) ||
+      (a_offsets.size() != batch_count) || (b_offsets.size() != batch_count) ||
+      (c_offsets.size() != batch_count)) {
+    throw BLASError(StatusCode::kInvalidBatchCount);
+  }
+
+  // Makes sure all dimensions are larger than zero
+  if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
+
+  // Computes whether or not the matrices are transposed in memory. See GEMM routine for details.
+  const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) ||
+                         (layout == Layout::kRowMajor && a_transpose == Transpose::kNo);
+  const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) ||
+                         (layout == Layout::kRowMajor && b_transpose == Transpose::kNo);
+  const auto c_rotated = (layout == Layout::kRowMajor);
+  static const auto a_want_rotated = false;
+  static const auto b_want_rotated = true;
+  static const auto c_want_rotated = false;
+
+  // A matrix needs a transpose when its in-memory orientation differs from the wanted one. These
+  // flags are computed but not yet consumed by this routine.
+  const auto a_do_transpose = a_rotated != a_want_rotated;
+  const auto b_do_transpose = b_rotated != b_want_rotated;
+  const auto c_do_transpose = c_rotated != c_want_rotated;
+
+  // In case of complex data-types, the transpose can also become a conjugate transpose. These
+  // flags are likewise not yet consumed by this routine.
+  const auto a_conjugate = (a_transpose == Transpose::kConjugate);
+  const auto b_conjugate = (b_transpose == Transpose::kConjugate);
+
+  // Computes the first and second dimensions of the 3 matrices taking into account whether the
+  // matrices are rotated or not
+  const auto a_one = (a_rotated) ? k : m;
+  const auto a_two = (a_rotated) ? m : k;
+  const auto b_one = (b_rotated) ? n : k;
+  const auto b_two = (b_rotated) ? k : n;
+  const auto c_one = (c_rotated) ? n : m;
+  const auto c_two = (c_rotated) ? m : n;
+
+  // Tests the matrices for validity: each batch entry has its own offset into the shared buffers
+  for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+    TestMatrixA(a_one, a_two, a_buffer, a_offsets[batch], a_ld);
+    TestMatrixB(b_one, b_two, b_buffer, b_offsets[batch], b_ld);
+    TestMatrixC(c_one, c_two, c_buffer, c_offsets[batch], c_ld);
+  }
+
+  // The computation itself is not implemented: the routine returns here without launching work.
+  // StatusCode::kNotImplemented;
+}
+
+// =================================================================================================
+
+// Compiles the templated class
+// NOTE(review): the five instantiation types were missing from the patch text; half, float,
+// double, float2 and double2 are assumed here -- confirm against the repository.
+template class XgemmBatched<half>;
+template class XgemmBatched<float>;
+template class XgemmBatched<double>;
+template class XgemmBatched<float2>;
+template class XgemmBatched<double2>;
+
+// =================================================================================================
+} // namespace clblast
diff --git a/src/routines/levelx/xgemmbatched.hpp b/src/routines/levelx/xgemmbatched.hpp
new file mode 100644
index 00000000..710011d8
--- /dev/null
+++ b/src/routines/levelx/xgemmbatched.hpp
@@ -0,0 +1,47 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+//   Cedric Nugteren
+//
+// This file implements the XgemmBatched routine. This is a non-BLAS batched version of GEMM.
+//
+// =================================================================================================
+
+#ifndef CLBLAST_ROUTINES_XGEMMBATCHED_H_
+#define CLBLAST_ROUTINES_XGEMMBATCHED_H_
+
+#include <string>
+#include <vector>
+
+#include "routine.hpp"
+
+namespace clblast {
+// =================================================================================================
+
+// See comment at top of file for a description of the class. The class is templated on the
+// precision of the routine and derives from the generic Routine base class.
+// NOTE(review): the contents of every <...> in this header (the template parameter list, the
+// standard header names, the element types of std::vector and Buffer) were missing from the patch
+// text and have been reconstructed from usage -- confirm against the repository.
+template <typename T>
+class XgemmBatched: public Routine {
+ public:
+
+  // Constructor: takes the queue and event to work with, plus the routine name, which defaults
+  // to "GEMMBATCHED"
+  XgemmBatched(Queue &queue, EventPointer event, const std::string &name = "GEMMBATCHED");
+
+  // Templated-precision implementation of the routine. The alphas, betas and the three offset
+  // vectors carry one entry per batch; the a, b and c buffers and their leading dimensions are
+  // shared by all batches.
+  void DoGemmBatched(const Layout layout,
+                     const Transpose a_transpose, const Transpose b_transpose,
+                     const size_t m, const size_t n, const size_t k,
+                     const std::vector<T> &alphas,
+                     const Buffer<T> &a_buffer,
+                     const std::vector<size_t> &a_offsets, const size_t a_ld,
+                     const Buffer<T> &b_buffer,
+                     const std::vector<size_t> &b_offsets, const size_t b_ld,
+                     const std::vector<T> &betas,
+                     const Buffer<T> &c_buffer,
+                     const std::vector<size_t> &c_offsets, const size_t c_ld,
+                     const size_t batch_count);
+};
+
+// =================================================================================================
+} // namespace clblast
+
+// CLBLAST_ROUTINES_XGEMMBATCHED_H_
+#endif
-- 
cgit v1.2.3