summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/clblast.cpp24
-rw-r--r--src/routines/level3/xtrsm.cpp72
-rw-r--r--src/routines/level3/xtrsm.hpp52
3 files changed, 141 insertions, 7 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 4bb4e0b3..68671e50 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -66,6 +66,7 @@
#include "routines/level3/xsyr2k.hpp"
#include "routines/level3/xher2k.hpp"
#include "routines/level3/xtrmm.hpp"
+#include "routines/level3/xtrsm.hpp"
// Level-x includes (non-BLAS)
#include "routines/levelx/xomatcopy.hpp"
@@ -2067,13 +2068,22 @@ template StatusCode PUBLIC_API Trmm<half>(const Layout, const Side, const Triang
// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM
template <typename T>
-StatusCode Trsm(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
- const size_t, const size_t,
- const T,
- const cl_mem, const size_t, const size_t,
- cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*) {
- return StatusCode::kNotImplemented;
+// Public entry point: builds an Xtrsm<T> routine on the caller's queue and forwards all arguments
+// to DoTrsm. Returns kSuccess when DoTrsm returns normally; if anything throws, the return value
+// is whatever DispatchException() produces.
+StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const T alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event) {
+ try {
+ // Builds a Queue object from the caller's command queue (the pointer is dereferenced unchecked)
+ auto queue_cpp = Queue(*queue);
+ auto routine = Xtrsm<T>(queue_cpp, event);
+ // The raw cl_mem handles are passed on as Buffer<T> objects; offsets and leading dimensions
+ // are forwarded unchanged
+ routine.DoTrsm(layout, side, triangle, a_transpose, diagonal,
+ m, n,
+ alpha,
+ Buffer<T>(a_buffer), a_offset, a_ld,
+ Buffer<T>(b_buffer), b_offset, b_ld);
+ return StatusCode::kSuccess;
+ // NOTE(review): DispatchException() is defined outside this hunk; presumably it maps the
+ // exception being handled to a StatusCode - confirm
+ } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Trsm<float>(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
diff --git a/src/routines/level3/xtrsm.cpp b/src/routines/level3/xtrsm.cpp
new file mode 100644
index 00000000..0ac1a58e
--- /dev/null
+++ b/src/routines/level3/xtrsm.cpp
@@ -0,0 +1,72 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file implements the Xtrsm class (see the header for information about the class).
+//
+// =================================================================================================
+
+#include "routines/level3/xtrsm.hpp"
+
+#include <string>
+#include <vector>
+
+namespace clblast {
+// =================================================================================================
+
+// Constructor: passes the queue, the event and the routine name straight on to the Xgemm<T> base
+template <typename T>
+Xtrsm<T>::Xtrsm(Queue &queue, EventPointer event, const std::string &name)
+    : Xgemm<T>(queue, event, name) {}
+
+// =================================================================================================
+
+// The main routine. Validates the dimensions and both matrices and stages a copy of B. The actual
+// triangular solve is not implemented yet, so the routine always ends by throwing
+// BLASError(StatusCode::kNotImplemented) instead of returning as if B had been solved.
+//
+// Throws BLASError(StatusCode::kInvalidDimension) when m or n is zero. TestMatrixA/TestMatrixB are
+// defined elsewhere; presumably they throw when a buffer, offset or leading dimension is invalid.
+template <typename T>
+void Xtrsm<T>::DoTrsm(const Layout layout, const Side side, const Triangle triangle,
+                      const Transpose a_transpose, const Diagonal diagonal,
+                      const size_t m, const size_t n,
+                      const T alpha,
+                      const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
+                      const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld) {
+
+  // Makes sure all dimensions are larger than zero
+  if ((m == 0) || (n == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
+
+  // Computes the size of the square triangular matrix A: it is m-by-m when A is applied on the
+  // left-hand side and n-by-n when it is applied on the right-hand side.
+  const auto k = (side == Side::kLeft) ? m : n;
+
+  // Checks for validity of the triangular A matrix
+  TestMatrixA(k, k, a_buffer, a_offset, a_ld);
+
+  // Checks for validity of the input/output B matrix. The first dimension is the one that is
+  // contiguous in memory for the given layout.
+  const auto b_one = (layout == Layout::kRowMajor) ? n : m;
+  const auto b_two = (layout == Layout::kRowMajor) ? m : n;
+  TestMatrixB(b_one, b_two, b_buffer, b_offset, b_ld);
+
+  // Creates a copy of B to avoid overwriting input in GEMM while computing output. The size
+  // includes the leading b_offset elements, so offsets into the copy match those into B.
+  const auto b_size = (b_ld * (b_two - 1) + b_one + b_offset);
+  auto b_buffer_copy = Buffer<T>(context_, b_size);
+  b_buffer.CopyTo(queue_, b_size, b_buffer_copy);
+
+  // TODO: Implement TRSM computation (triangle, a_transpose, diagonal and alpha are still unused)
+  // Until then, signal the missing implementation explicitly: returning normally would make the
+  // public API report success while B is left untouched.
+  throw BLASError(StatusCode::kNotImplemented);
+}
+
+// =================================================================================================
+
+// Compiles the templated class: the member functions are defined in this source file rather than
+// in the header, so each precision used by callers has to be instantiated explicitly here.
+template class Xtrsm<half>;
+template class Xtrsm<float>;
+template class Xtrsm<double>;
+template class Xtrsm<float2>;
+template class Xtrsm<double2>;
+
+// =================================================================================================
+} // namespace clblast
diff --git a/src/routines/level3/xtrsm.hpp b/src/routines/level3/xtrsm.hpp
new file mode 100644
index 00000000..288e9d11
--- /dev/null
+++ b/src/routines/level3/xtrsm.hpp
@@ -0,0 +1,52 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file implements the Xtrsm routine. The algorithm it will be based on is still to be
+// decided (TODO). The class inherits from the Xgemm class so that it can reuse DoGemm.
+//
+// =================================================================================================
+
+#ifndef CLBLAST_ROUTINES_XTRSM_H_
+#define CLBLAST_ROUTINES_XTRSM_H_
+
+#include "routines/level3/xgemm.hpp"
+
+namespace clblast {
+// =================================================================================================
+
+// See comment at top of file for a description of the class
+template <typename T>
+class Xtrsm: public Xgemm<T> {
+ public:
+
+ // Uses methods and variables from the Xgemm routine. The using-declarations make these
+ // members of the dependent base class Xgemm<T> usable by their plain names in this template.
+ using Xgemm<T>::routine_name_;
+ using Xgemm<T>::queue_;
+ using Xgemm<T>::context_;
+ using Xgemm<T>::device_;
+ using Xgemm<T>::db_;
+ using Xgemm<T>::DoGemm;
+
+ // Constructor: the routine name defaults to "TRSM" and is passed on to the Xgemm base class
+ Xtrsm(Queue &queue, EventPointer event, const std::string &name = "TRSM");
+
+ // Templated-precision implementation of the routine. a_buffer holds the triangular matrix A
+ // and b_buffer the input/output matrix B; each comes with an offset and a leading dimension.
+ // m and n must both be non-zero.
+ void DoTrsm(const Layout layout, const Side side, const Triangle triangle,
+ const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const T alpha,
+ const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
+ const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld);
+};
+
+// =================================================================================================
+} // namespace clblast
+
+// CLBLAST_ROUTINES_XTRSM_H_
+#endif