diff options
author | CNugteren <web@cedricnugteren.nl> | 2015-06-30 07:36:11 +0200 |
---|---|---|
committer | CNugteren <web@cedricnugteren.nl> | 2015-06-30 07:36:11 +0200 |
commit | 8574f72d46f8f3572e2a5e9f24359d8da18ccf2a (patch) | |
tree | 2a0cba21da620ae0a37c121e6fb10a3aa18b715f | |
parent | a591d5607dcce7588718cb5f33176f3474e1f05d (diff) |
Added the TRMM and TRSM interface
-rw-r--r-- | include/clblast.h | 29 | ||||
-rw-r--r-- | src/clblast.cc | 142 |
2 files changed, 159 insertions, 12 deletions
diff --git a/include/clblast.h b/include/clblast.h index da504a0b..5da10810 100644 --- a/include/clblast.h +++ b/include/clblast.h @@ -75,6 +75,7 @@ enum class Layout { kRowMajor, kColMajor }; enum class Transpose { kNo, kYes, kConjugate }; enum class Side { kLeft, kRight }; enum class Triangle { kUpper, kLower }; +enum class Diagonal { kUnit, kNonUnit }; // Precision scoped enum (values in bits) enum class Precision { kHalf = 16, kSingle = 32, kDouble = 64, @@ -95,7 +96,7 @@ StatusCode Axpy(const size_t n, const T alpha, // Templated-precision generalized matrix-vector multiplication: SGEMV/DGEMV/CGEMV/ZGEMV template <typename T> -StatusCode Gemv(const Layout layout, const Transpose transpose_a, +StatusCode Gemv(const Layout layout, const Transpose a_transpose, const size_t m, const size_t n, const T alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, @@ -109,7 +110,7 @@ StatusCode Gemv(const Layout layout, const Transpose transpose_a, // Templated-precision generalized matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM template <typename T> -StatusCode Gemm(const Layout layout, const Transpose transpose_a, const Transpose transpose_b, +StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, const size_t m, const size_t n, const size_t k, const T alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, @@ -131,7 +132,7 @@ StatusCode Symm(const Layout layout, const Side side, const Triangle triangle, // Templated-precision rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK template <typename T> -StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose transpose_a, +StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose a_transpose, const size_t n, const size_t k, const T alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, @@ -141,7 +142,7 @@ StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose tr // Templated-precision rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K template <typename T> -StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose transpose_ab, +StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, const size_t n, const size_t k, const T alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, @@ -150,6 +151,26 @@ StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose t cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event); +// Templated-precision triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM +template <typename T> +StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, + const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const T alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + cl_command_queue* queue, cl_event* event); + +// Templated-precision matrix equation solver: STRSM/DTRSM/CTRSM/ZTRSM +template <typename T> +StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, + const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const T alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + cl_command_queue* queue, cl_event* event); + // ================================================================================================= } // namespace clblast diff --git a/src/clblast.cc b/src/clblast.cc index b8aa1e39..e3ce4d39 100644 --- a/src/clblast.cc +++ b/src/clblast.cc @@ -76,7 +76,7 @@ template StatusCode Axpy<double2>(const size_t, const double2, // GEMV template <typename T> -StatusCode Gemv(const Layout layout, const Transpose transpose_a, +StatusCode Gemv(const Layout layout, const Transpose a_transpose, const size_t m, const size_t n, const T alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, const T beta, @@ -94,7 +94,7 @@ StatusCode Gemv(const Layout layout, const Transpose transpose_a, if (status != StatusCode::kSuccess) { return status; } // Runs the routine - return routine.DoGemv(layout, transpose_a, m, n, alpha, + return routine.DoGemv(layout, a_transpose, m, n, alpha, Buffer(a_buffer), a_offset, a_ld, Buffer(x_buffer), x_offset, x_inc, beta, Buffer(y_buffer), y_offset, y_inc); @@ -129,7 +129,7 @@ template StatusCode Gemv<double2>(const Layout, const Transpose, // GEMM template <typename T> -StatusCode Gemm(const Layout layout, const Transpose transpose_a, const Transpose transpose_b, +StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, const size_t m, const size_t n, const size_t k, const T alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const T beta, @@ -155,7 +155,7 @@ StatusCode Gemm(const Layout layout, const Transpose transpose_a, const Transpos if (status != StatusCode::kSuccess) { return status; } // Runs the routine - return routine.DoGemm(layout, transpose_a, transpose_b, m, n, k, alpha, + return routine.DoGemm(layout, a_transpose, b_transpose, m, n, k, alpha, Buffer(a_buffer), a_offset, a_ld, Buffer(b_buffer), b_offset, b_ld, beta, Buffer(c_buffer), c_offset, c_ld); @@ -249,7 +249,7 @@ template StatusCode Symm<double2>(const Layout, const Side, const Triangle, // SYRK template <typename T> -StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose transpose_a, +StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose a_transpose, const size_t n, const size_t k, const T alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const T beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, @@ -274,7 +274,7 @@ StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose tr if (status != StatusCode::kSuccess) { return status; } // Runs the routine - return routine.DoSyrk(layout, triangle, transpose_a, n, k, alpha, + return routine.DoSyrk(layout, triangle, a_transpose, n, k, alpha, Buffer(a_buffer), a_offset, a_ld, beta, Buffer(c_buffer), c_offset, c_ld); } @@ -303,7 +303,7 @@ template StatusCode Syrk<double2>(const Layout, const Triangle, const Transpose, // SYR2K template <typename T> -StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose transpose_ab, +StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, const size_t n, const size_t k, const T alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const T beta, @@ -329,7 +329,7 @@ StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose t if (status != StatusCode::kSuccess) { return status; } // Runs the routine - return routine.DoSyr2k(layout, triangle, transpose_ab, n, k, alpha, + return routine.DoSyr2k(layout, triangle, ab_transpose, n, k, alpha, Buffer(a_buffer), a_offset, a_ld, Buffer(b_buffer), b_offset, b_ld, beta, Buffer(c_buffer), c_offset, c_ld); @@ -360,4 +360,130 @@ template StatusCode Syr2k<double2>(const Layout, const Triangle, const Transpose cl_command_queue*, cl_event*); // ================================================================================================= + +// TRMM +template <typename T> +StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, + const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const T alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + cl_command_queue* queue, cl_event* event) { + auto queue_cpp = CommandQueue(*queue); + auto event_cpp = Event(*event); + /* + auto routine = Xtrmm<T>(queue_cpp, event_cpp); + + // Loads the kernel source-code as an include (C++11 raw string literal) + std::string common_source1 = + #include "kernels/copy.opencl" + std::string common_source2 = + #include "kernels/pad.opencl" + std::string common_source3 = + #include "kernels/transpose.opencl" + std::string common_source4 = + #include "kernels/padtranspose.opencl" + std::string kernel_source = + #include "kernels/xgemm.opencl" + auto status = routine.SetUp(common_source1 + common_source2 + common_source3 + common_source4 + + kernel_source); + if (status != StatusCode::kSuccess) { return status; } + + // Runs the routine + return routine.DoTrmm(layout, side, triangle, a_transpose, diagonal, m, n, alpha, + Buffer(a_buffer), a_offset, a_ld, + Buffer(b_buffer), b_offset, b_ld); + */ + return StatusCode::kSuccess; +} +template StatusCode Trmm<float>(const Layout, const Side, const Triangle, + const Transpose, const Diagonal, + const size_t, const size_t, const float, + const cl_mem, const size_t, const size_t, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); +template StatusCode Trmm<double>(const Layout, const Side, const Triangle, + const Transpose, const Diagonal, + const size_t, const size_t, const double, + const cl_mem, const size_t, const size_t, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); +template StatusCode Trmm<float2>(const Layout, const Side, const Triangle, + const Transpose, const Diagonal, + const size_t, const size_t, const float2, + const cl_mem, const size_t, const size_t, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); +template StatusCode Trmm<double2>(const Layout, const Side, const Triangle, + const Transpose, const Diagonal, + const size_t, const size_t, const double2, + const cl_mem, const size_t, const size_t, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); + +// ================================================================================================= + +// TRSM +template <typename T> +StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, + const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const T alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + cl_command_queue* queue, cl_event* event) { + auto queue_cpp = CommandQueue(*queue); + auto event_cpp = Event(*event); + /* + auto routine = Xtrsm<T>(queue_cpp, event_cpp); + + // Loads the kernel source-code as an include (C++11 raw string literal) + std::string common_source1 = + #include "kernels/copy.opencl" + std::string common_source2 = + #include "kernels/pad.opencl" + std::string common_source3 = + #include "kernels/transpose.opencl" + std::string common_source4 = + #include "kernels/padtranspose.opencl" + std::string kernel_source = + #include "kernels/xgemm.opencl" + auto status = routine.SetUp(common_source1 + common_source2 + common_source3 + common_source4 + + kernel_source); + if (status != StatusCode::kSuccess) { return status; } + + // Runs the routine + return routine.DoTrsm(layout, side, triangle, a_transpose, diagonal, m, n, alpha, + Buffer(a_buffer), a_offset, a_ld, + Buffer(b_buffer), b_offset, b_ld); + */ + return StatusCode::kSuccess; +} +template StatusCode Trsm<float>(const Layout, const Side, const Triangle, + const Transpose, const Diagonal, + const size_t, const size_t, const float, + const cl_mem, const size_t, const size_t, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); +template StatusCode Trsm<double>(const Layout, const Side, const Triangle, + const Transpose, const Diagonal, + const size_t, const size_t, const double, + const cl_mem, const size_t, const size_t, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); +template StatusCode Trsm<float2>(const Layout, const Side, const Triangle, + const Transpose, const Diagonal, + const size_t, const size_t, const float2, + const cl_mem, const size_t, const size_t, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); +template StatusCode Trsm<double2>(const Layout, const Side, const Triangle, + const Transpose, const Diagonal, + const size_t, const size_t, const double2, + const cl_mem, const size_t, const size_t, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); + +// ================================================================================================= } // namespace clblast |