diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/clblast.cpp | 24 | ||||
-rw-r--r-- | src/routines/level3/xtrsm.cpp | 72 | ||||
-rw-r--r-- | src/routines/level3/xtrsm.hpp | 52 |
3 files changed, 141 insertions, 7 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp index 4bb4e0b3..68671e50 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -66,6 +66,7 @@ #include "routines/level3/xsyr2k.hpp" #include "routines/level3/xher2k.hpp" #include "routines/level3/xtrmm.hpp" +#include "routines/level3/xtrsm.hpp" // Level-x includes (non-BLAS) #include "routines/levelx/xomatcopy.hpp" @@ -2067,13 +2068,22 @@ template StatusCode PUBLIC_API Trmm<half>(const Layout, const Side, const Triang // Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM template <typename T> -StatusCode Trsm(const Layout, const Side, const Triangle, const Transpose, const Diagonal, - const size_t, const size_t, - const T, - const cl_mem, const size_t, const size_t, - cl_mem, const size_t, const size_t, - cl_command_queue*, cl_event*) { - return StatusCode::kNotImplemented; +StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const T alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + cl_command_queue* queue, cl_event* event) { + try { + auto queue_cpp = Queue(*queue); + auto routine = Xtrsm<T>(queue_cpp, event); + routine.DoTrsm(layout, side, triangle, a_transpose, diagonal, + m, n, + alpha, + Buffer<T>(a_buffer), a_offset, a_ld, + Buffer<T>(b_buffer), b_offset, b_ld); + return StatusCode::kSuccess; + } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Trsm<float>(const Layout, const Side, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, diff --git a/src/routines/level3/xtrsm.cpp b/src/routines/level3/xtrsm.cpp new file mode 100644 index 00000000..0ac1a58e --- /dev/null +++ b/src/routines/level3/xtrsm.cpp @@ -0,0 +1,72 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren <www.cedricnugteren.nl> +// +// This file implements the Xtrsm class (see the header for information about the class). +// +// ================================================================================================= + +#include "routines/level3/xtrsm.hpp" + +#include <string> +#include <vector> + +namespace clblast { +// ================================================================================================= + +// Constructor: forwards to base class constructor +template <typename T> +Xtrsm<T>::Xtrsm(Queue &queue, EventPointer event, const std::string &name): + Xgemm<T>(queue, event, name) { +} + +// ================================================================================================= + +// The main routine +template <typename T> +void Xtrsm<T>::DoTrsm(const Layout layout, const Side side, const Triangle triangle, + const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const T alpha, + const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld) { + + // Makes sure all dimensions are larger than zero + if ((m == 0) || (n == 0)) { throw BLASError(StatusCode::kInvalidDimension); } + + // 
Computes the k dimension. This is based on whether the matrix is A (on the left) + // or B (on the right) in the Xgemm routine. + auto k = (side == Side::kLeft) ? m : n; + + // Checks for validity of the triangular A matrix + TestMatrixA(k, k, a_buffer, a_offset, a_ld); + + // Checks for validity of the input/output B matrix + const auto b_one = (layout == Layout::kRowMajor) ? n : m; + const auto b_two = (layout == Layout::kRowMajor) ? m : n; + TestMatrixB(b_one, b_two, b_buffer, b_offset, b_ld); + + // Creates a copy of B to avoid overwriting input in GEMM while computing output + const auto b_size = (b_ld * (b_two - 1) + b_one + b_offset); + auto b_buffer_copy = Buffer<T>(context_, b_size); + b_buffer.CopyTo(queue_, b_size, b_buffer_copy); + + // TODO: Implement TRSM computation +} + +// ================================================================================================= + +// Compiles the templated class +template class Xtrsm<half>; +template class Xtrsm<float>; +template class Xtrsm<double>; +template class Xtrsm<float2>; +template class Xtrsm<double2>; + +// ================================================================================================= +} // namespace clblast diff --git a/src/routines/level3/xtrsm.hpp b/src/routines/level3/xtrsm.hpp new file mode 100644 index 00000000..288e9d11 --- /dev/null +++ b/src/routines/level3/xtrsm.hpp @@ -0,0 +1,52 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren <www.cedricnugteren.nl> +// +// This file implements the Xtrsm routine. The implementation is based on ??? (TODO). +// Therefore, this class inherits from the Xgemm class.
+// +// ================================================================================================= + +#ifndef CLBLAST_ROUTINES_XTRSM_H_ +#define CLBLAST_ROUTINES_XTRSM_H_ + +#include "routines/level3/xgemm.hpp" + +namespace clblast { +// ================================================================================================= + +// See comment at top of file for a description of the class +template <typename T> +class Xtrsm: public Xgemm<T> { + public: + + // Uses methods and variables from the Xgemm routine + using Xgemm<T>::routine_name_; + using Xgemm<T>::queue_; + using Xgemm<T>::context_; + using Xgemm<T>::device_; + using Xgemm<T>::db_; + using Xgemm<T>::DoGemm; + + // Constructor + Xtrsm(Queue &queue, EventPointer event, const std::string &name = "TRSM"); + + // Templated-precision implementation of the routine + void DoTrsm(const Layout layout, const Side side, const Triangle triangle, + const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const T alpha, + const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld); +}; + +// ================================================================================================= +} // namespace clblast + +// CLBLAST_ROUTINES_XTRSM_H_ +#endif |