summaryrefslogtreecommitdiff
path: root/src/clblast.cpp
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2016-12-18 12:30:16 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2016-12-18 12:30:16 +0100
commit681a465b355d0fed7a82644f4472a5afa024721e (patch)
tree6a463b8a15090e09d8b0c72a54b470093428232b /src/clblast.cpp
parent6b533dda1ce8b4feda68708dec779ddc6200480c (diff)
Prepared for the addition of the TRSM triangular solver kernel
Diffstat (limited to 'src/clblast.cpp')
-rw-r--r--src/clblast.cpp24
1 files changed, 17 insertions, 7 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 4bb4e0b3..68671e50 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -66,6 +66,7 @@
#include "routines/level3/xsyr2k.hpp"
#include "routines/level3/xher2k.hpp"
#include "routines/level3/xtrmm.hpp"
+#include "routines/level3/xtrsm.hpp"
// Level-x includes (non-BLAS)
#include "routines/levelx/xomatcopy.hpp"
@@ -2067,13 +2068,22 @@ template StatusCode PUBLIC_API Trmm<half>(const Layout, const Side, const Triang
// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM
template <typename T>
-StatusCode Trsm(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
- const size_t, const size_t,
- const T,
- const cl_mem, const size_t, const size_t,
- cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*) {
- return StatusCode::kNotImplemented;
+StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const T alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event) {
+ try {
+ auto queue_cpp = Queue(*queue);
+ auto routine = Xtrsm<T>(queue_cpp, event);
+ routine.DoTrsm(layout, side, triangle, a_transpose, diagonal,
+ m, n,
+ alpha,
+ Buffer<T>(a_buffer), a_offset, a_ld,
+ Buffer<T>(b_buffer), b_offset, b_ld);
+ return StatusCode::kSuccess;
+ } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Trsm<float>(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,