From 681a465b355d0fed7a82644f4472a5afa024721e Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Sun, 18 Dec 2016 12:30:16 +0100 Subject: Prepared for the addition of the TRSM triangular solver kernel --- src/routines/level3/xtrsm.cpp | 72 +++++++++++++++++++++++++++++++++++++++++++ src/routines/level3/xtrsm.hpp | 52 +++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 src/routines/level3/xtrsm.cpp create mode 100644 src/routines/level3/xtrsm.hpp (limited to 'src/routines') diff --git a/src/routines/level3/xtrsm.cpp b/src/routines/level3/xtrsm.cpp new file mode 100644 index 00000000..0ac1a58e --- /dev/null +++ b/src/routines/level3/xtrsm.cpp @@ -0,0 +1,72 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren +// +// This file implements the Xtrsm class (see the header for information about the class). +// +// ================================================================================================= + +#include "routines/level3/xtrsm.hpp" + +#include +#include + +namespace clblast { +// ================================================================================================= + +// Constructor: forwards to base class constructor +template +Xtrsm::Xtrsm(Queue &queue, EventPointer event, const std::string &name): + Xgemm(queue, event, name) { +} + +// ================================================================================================= + +// The main routine +template +void Xtrsm::DoTrsm(const Layout layout, const Side side, const Triangle triangle, + const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const T alpha, + const Buffer &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer &b_buffer, const size_t b_offset, const size_t b_ld) { + + // Makes sure all dimensions are larger than zero + if ((m == 0) || (n == 0)) { throw BLASError(StatusCode::kInvalidDimension); } + + // Computes the k dimension. This is based on whether or not matrix is A (on the left) + // or B (on the right) in the Xgemm routine. + auto k = (side == Side::kLeft) ? m : n; + + // Checks for validity of the triangular A matrix + TestMatrixA(k, k, a_buffer, a_offset, a_ld); + + // Checks for validity of the input/output B matrix + const auto b_one = (layout == Layout::kRowMajor) ? n : m; + const auto b_two = (layout == Layout::kRowMajor) ? m : n; + TestMatrixB(b_one, b_two, b_buffer, b_offset, b_ld); + + // Creates a copy of B to avoid overwriting input in GEMM while computing output + const auto b_size = (b_ld * (b_two - 1) + b_one + b_offset); + auto b_buffer_copy = Buffer(context_, b_size); + b_buffer.CopyTo(queue_, b_size, b_buffer_copy); + + // TODO: Implement TRSM computation +} + +// ================================================================================================= + +// Compiles the templated class +template class Xtrsm; +template class Xtrsm; +template class Xtrsm; +template class Xtrsm; +template class Xtrsm; + +// ================================================================================================= +} // namespace clblast diff --git a/src/routines/level3/xtrsm.hpp b/src/routines/level3/xtrsm.hpp new file mode 100644 index 00000000..288e9d11 --- /dev/null +++ b/src/routines/level3/xtrsm.hpp @@ -0,0 +1,52 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren +// +// This file implements the Xtrsm routine. The implementation is based on ??? (TODO). +// Therefore, this class inherits from the Xgemm class. +// +// ================================================================================================= + +#ifndef CLBLAST_ROUTINES_XTRSM_H_ +#define CLBLAST_ROUTINES_XTRSM_H_ + +#include "routines/level3/xgemm.hpp" + +namespace clblast { +// ================================================================================================= + +// See comment at top of file for a description of the class +template +class Xtrsm: public Xgemm { + public: + + // Uses methods and variables the Xgemm routine + using Xgemm::routine_name_; + using Xgemm::queue_; + using Xgemm::context_; + using Xgemm::device_; + using Xgemm::db_; + using Xgemm::DoGemm; + + // Constructor + Xtrsm(Queue &queue, EventPointer event, const std::string &name = "TRSM"); + + // Templated-precision implementation of the routine + void DoTrsm(const Layout layout, const Side side, const Triangle triangle, + const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const T alpha, + const Buffer &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer &b_buffer, const size_t b_offset, const size_t b_ld); +}; + +// ================================================================================================= +} // namespace clblast + +// CLBLAST_ROUTINES_XTRSM_H_ +#endif -- cgit v1.2.3