summaryrefslogtreecommitdiff
path: root/src/routines
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2016-12-18 12:30:16 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2016-12-18 12:30:16 +0100
commit681a465b355d0fed7a82644f4472a5afa024721e (patch)
tree6a463b8a15090e09d8b0c72a54b470093428232b /src/routines
parent6b533dda1ce8b4feda68708dec779ddc6200480c (diff)
Prepared for the addition of the TRSM triangular solver kernel
Diffstat (limited to 'src/routines')
-rw-r--r--src/routines/level3/xtrsm.cpp72
-rw-r--r--src/routines/level3/xtrsm.hpp52
2 files changed, 124 insertions, 0 deletions
diff --git a/src/routines/level3/xtrsm.cpp b/src/routines/level3/xtrsm.cpp
new file mode 100644
index 00000000..0ac1a58e
--- /dev/null
+++ b/src/routines/level3/xtrsm.cpp
@@ -0,0 +1,72 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file implements the Xtrsm class (see the header for information about the class).
+//
+// =================================================================================================
+
+#include "routines/level3/xtrsm.hpp"
+
+#include <string>
+#include <vector>
+
+namespace clblast {
+// =================================================================================================
+
+// Constructor: forwards to base class constructor
+template <typename T>
+Xtrsm<T>::Xtrsm(Queue &queue, EventPointer event, const std::string &name):
+ Xgemm<T>(queue, event, name) {
+}
+
+// =================================================================================================
+
+// The main routine
+template <typename T>
+void Xtrsm<T>::DoTrsm(const Layout layout, const Side side, const Triangle triangle,
+ const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const T alpha,
+ const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
+ const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld) {
+
+ // Makes sure all dimensions are larger than zero
+ if ((m == 0) || (n == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
+
+ // Computes the k dimension. This is based on whether or not matrix is A (on the left)
+ // or B (on the right) in the Xgemm routine.
+ auto k = (side == Side::kLeft) ? m : n;
+
+ // Checks for validity of the triangular A matrix
+ TestMatrixA(k, k, a_buffer, a_offset, a_ld);
+
+ // Checks for validity of the input/output B matrix
+ const auto b_one = (layout == Layout::kRowMajor) ? n : m;
+ const auto b_two = (layout == Layout::kRowMajor) ? m : n;
+ TestMatrixB(b_one, b_two, b_buffer, b_offset, b_ld);
+
+ // Creates a copy of B to avoid overwriting input in GEMM while computing output
+ const auto b_size = (b_ld * (b_two - 1) + b_one + b_offset);
+ auto b_buffer_copy = Buffer<T>(context_, b_size);
+ b_buffer.CopyTo(queue_, b_size, b_buffer_copy);
+
+ // TODO: Implement TRSM computation
+}
+
+// =================================================================================================
+
+// Compiles the templated class
+template class Xtrsm<half>;
+template class Xtrsm<float>;
+template class Xtrsm<double>;
+template class Xtrsm<float2>;
+template class Xtrsm<double2>;
+
+// =================================================================================================
+} // namespace clblast
diff --git a/src/routines/level3/xtrsm.hpp b/src/routines/level3/xtrsm.hpp
new file mode 100644
index 00000000..288e9d11
--- /dev/null
+++ b/src/routines/level3/xtrsm.hpp
@@ -0,0 +1,52 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file implements the Xtrsm routine. The implementation is based on ??? (TODO).
+// Therefore, this class inherits from the Xgemm class.
+//
+// =================================================================================================
+
+#ifndef CLBLAST_ROUTINES_XTRSM_H_
+#define CLBLAST_ROUTINES_XTRSM_H_
+
+#include "routines/level3/xgemm.hpp"
+
+namespace clblast {
+// =================================================================================================
+
+// See comment at top of file for a description of the class
+template <typename T>
+class Xtrsm: public Xgemm<T> {
+ public:
+
+ // Uses methods and variables the Xgemm routine
+ using Xgemm<T>::routine_name_;
+ using Xgemm<T>::queue_;
+ using Xgemm<T>::context_;
+ using Xgemm<T>::device_;
+ using Xgemm<T>::db_;
+ using Xgemm<T>::DoGemm;
+
+ // Constructor
+ Xtrsm(Queue &queue, EventPointer event, const std::string &name = "TRSM");
+
+ // Templated-precision implementation of the routine
+ void DoTrsm(const Layout layout, const Side side, const Triangle triangle,
+ const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const T alpha,
+ const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
+ const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld);
+};
+
+// =================================================================================================
+} // namespace clblast
+
+// CLBLAST_ROUTINES_XTRSM_H_
+#endif