summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCNugteren <web@cedricnugteren.nl>2015-06-30 07:36:11 +0200
committerCNugteren <web@cedricnugteren.nl>2015-06-30 07:36:11 +0200
commit8574f72d46f8f3572e2a5e9f24359d8da18ccf2a (patch)
tree2a0cba21da620ae0a37c121e6fb10a3aa18b715f
parenta591d5607dcce7588718cb5f33176f3474e1f05d (diff)
Added the TRMM and TRSM interface
-rw-r--r--include/clblast.h29
-rw-r--r--src/clblast.cc142
2 files changed, 159 insertions, 12 deletions
diff --git a/include/clblast.h b/include/clblast.h
index da504a0b..5da10810 100644
--- a/include/clblast.h
+++ b/include/clblast.h
@@ -75,6 +75,7 @@ enum class Layout { kRowMajor, kColMajor };
enum class Transpose { kNo, kYes, kConjugate };
enum class Side { kLeft, kRight };
enum class Triangle { kUpper, kLower };
+enum class Diagonal { kUnit, kNonUnit };
// Precision scoped enum (values in bits)
enum class Precision { kHalf = 16, kSingle = 32, kDouble = 64,
@@ -95,7 +96,7 @@ StatusCode Axpy(const size_t n, const T alpha,
// Templated-precision generalized matrix-vector multiplication: SGEMV/DGEMV/CGEMV/ZGEMV
template <typename T>
-StatusCode Gemv(const Layout layout, const Transpose transpose_a,
+StatusCode Gemv(const Layout layout, const Transpose a_transpose,
const size_t m, const size_t n,
const T alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
@@ -109,7 +110,7 @@ StatusCode Gemv(const Layout layout, const Transpose transpose_a,
// Templated-precision generalized matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM
template <typename T>
-StatusCode Gemm(const Layout layout, const Transpose transpose_a, const Transpose transpose_b,
+StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k,
const T alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
@@ -131,7 +132,7 @@ StatusCode Symm(const Layout layout, const Side side, const Triangle triangle,
// Templated-precision rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK
template <typename T>
-StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose transpose_a,
+StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose a_transpose,
const size_t n, const size_t k,
const T alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
@@ -141,7 +142,7 @@ StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose tr
// Templated-precision rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K
template <typename T>
-StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose transpose_ab,
+StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose,
const size_t n, const size_t k,
const T alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
@@ -150,6 +151,26 @@ StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose t
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
cl_command_queue* queue, cl_event* event);
+// Templated-precision triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM
+template <typename T>
+StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle,
+ const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const T alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event);
+
+// Templated-precision matrix equation solver: STRSM/DTRSM/CTRSM/ZTRSM
+template <typename T>
+StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle,
+ const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const T alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event);
+
// =================================================================================================
} // namespace clblast
diff --git a/src/clblast.cc b/src/clblast.cc
index b8aa1e39..e3ce4d39 100644
--- a/src/clblast.cc
+++ b/src/clblast.cc
@@ -76,7 +76,7 @@ template StatusCode Axpy<double2>(const size_t, const double2,
// GEMV
template <typename T>
-StatusCode Gemv(const Layout layout, const Transpose transpose_a,
+StatusCode Gemv(const Layout layout, const Transpose a_transpose,
const size_t m, const size_t n, const T alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
const cl_mem x_buffer, const size_t x_offset, const size_t x_inc, const T beta,
@@ -94,7 +94,7 @@ StatusCode Gemv(const Layout layout, const Transpose transpose_a,
if (status != StatusCode::kSuccess) { return status; }
// Runs the routine
- return routine.DoGemv(layout, transpose_a, m, n, alpha,
+ return routine.DoGemv(layout, a_transpose, m, n, alpha,
Buffer(a_buffer), a_offset, a_ld,
Buffer(x_buffer), x_offset, x_inc, beta,
Buffer(y_buffer), y_offset, y_inc);
@@ -129,7 +129,7 @@ template StatusCode Gemv<double2>(const Layout, const Transpose,
// GEMM
template <typename T>
-StatusCode Gemm(const Layout layout, const Transpose transpose_a, const Transpose transpose_b,
+StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k, const T alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const T beta,
@@ -155,7 +155,7 @@ StatusCode Gemm(const Layout layout, const Transpose transpose_a, const Transpos
if (status != StatusCode::kSuccess) { return status; }
// Runs the routine
- return routine.DoGemm(layout, transpose_a, transpose_b, m, n, k, alpha,
+ return routine.DoGemm(layout, a_transpose, b_transpose, m, n, k, alpha,
Buffer(a_buffer), a_offset, a_ld,
Buffer(b_buffer), b_offset, b_ld, beta,
Buffer(c_buffer), c_offset, c_ld);
@@ -249,7 +249,7 @@ template StatusCode Symm<double2>(const Layout, const Side, const Triangle,
// SYRK
template <typename T>
-StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose transpose_a,
+StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose a_transpose,
const size_t n, const size_t k, const T alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const T beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
@@ -274,7 +274,7 @@ StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose tr
if (status != StatusCode::kSuccess) { return status; }
// Runs the routine
- return routine.DoSyrk(layout, triangle, transpose_a, n, k, alpha,
+ return routine.DoSyrk(layout, triangle, a_transpose, n, k, alpha,
Buffer(a_buffer), a_offset, a_ld, beta,
Buffer(c_buffer), c_offset, c_ld);
}
@@ -303,7 +303,7 @@ template StatusCode Syrk<double2>(const Layout, const Triangle, const Transpose,
// SYR2K
template <typename T>
-StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose transpose_ab,
+StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose,
const size_t n, const size_t k, const T alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const T beta,
@@ -329,7 +329,7 @@ StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose t
if (status != StatusCode::kSuccess) { return status; }
// Runs the routine
- return routine.DoSyr2k(layout, triangle, transpose_ab, n, k, alpha,
+ return routine.DoSyr2k(layout, triangle, ab_transpose, n, k, alpha,
Buffer(a_buffer), a_offset, a_ld,
Buffer(b_buffer), b_offset, b_ld, beta,
Buffer(c_buffer), c_offset, c_ld);
@@ -360,4 +360,130 @@ template StatusCode Syr2k<double2>(const Layout, const Triangle, const Transpose
cl_command_queue*, cl_event*);
// =================================================================================================
+
+// TRMM
+template <typename T>
+StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle,
+ const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const T alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event) {
+ auto queue_cpp = CommandQueue(*queue);
+ auto event_cpp = Event(*event);
+ /*
+ auto routine = Xtrmm<T>(queue_cpp, event_cpp);
+
+ // Loads the kernel source-code as an include (C++11 raw string literal)
+ std::string common_source1 =
+ #include "kernels/copy.opencl"
+ std::string common_source2 =
+ #include "kernels/pad.opencl"
+ std::string common_source3 =
+ #include "kernels/transpose.opencl"
+ std::string common_source4 =
+ #include "kernels/padtranspose.opencl"
+ std::string kernel_source =
+ #include "kernels/xgemm.opencl"
+ auto status = routine.SetUp(common_source1 + common_source2 + common_source3 + common_source4 +
+ kernel_source);
+ if (status != StatusCode::kSuccess) { return status; }
+
+ // Runs the routine
+ return routine.DoTrmm(layout, side, triangle, a_transpose, diagonal, m, n, alpha,
+ Buffer(a_buffer), a_offset, a_ld,
+ Buffer(b_buffer), b_offset, b_ld);
+ */
+ return StatusCode::kSuccess;
+}
+template StatusCode Trmm<float>(const Layout, const Side, const Triangle,
+ const Transpose, const Diagonal,
+ const size_t, const size_t, const float,
+ const cl_mem, const size_t, const size_t,
+ cl_mem, const size_t, const size_t,
+ cl_command_queue*, cl_event*);
+template StatusCode Trmm<double>(const Layout, const Side, const Triangle,
+ const Transpose, const Diagonal,
+ const size_t, const size_t, const double,
+ const cl_mem, const size_t, const size_t,
+ cl_mem, const size_t, const size_t,
+ cl_command_queue*, cl_event*);
+template StatusCode Trmm<float2>(const Layout, const Side, const Triangle,
+ const Transpose, const Diagonal,
+ const size_t, const size_t, const float2,
+ const cl_mem, const size_t, const size_t,
+ cl_mem, const size_t, const size_t,
+ cl_command_queue*, cl_event*);
+template StatusCode Trmm<double2>(const Layout, const Side, const Triangle,
+ const Transpose, const Diagonal,
+ const size_t, const size_t, const double2,
+ const cl_mem, const size_t, const size_t,
+ cl_mem, const size_t, const size_t,
+ cl_command_queue*, cl_event*);
+
+// =================================================================================================
+
+// TRSM
+template <typename T>
+StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle,
+ const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const T alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event) {
+ auto queue_cpp = CommandQueue(*queue);
+ auto event_cpp = Event(*event);
+ /*
+ auto routine = Xtrsm<T>(queue_cpp, event_cpp);
+
+ // Loads the kernel source-code as an include (C++11 raw string literal)
+ std::string common_source1 =
+ #include "kernels/copy.opencl"
+ std::string common_source2 =
+ #include "kernels/pad.opencl"
+ std::string common_source3 =
+ #include "kernels/transpose.opencl"
+ std::string common_source4 =
+ #include "kernels/padtranspose.opencl"
+ std::string kernel_source =
+ #include "kernels/xgemm.opencl"
+ auto status = routine.SetUp(common_source1 + common_source2 + common_source3 + common_source4 +
+ kernel_source);
+ if (status != StatusCode::kSuccess) { return status; }
+
+ // Runs the routine
+ return routine.DoTrsm(layout, side, triangle, a_transpose, diagonal, m, n, alpha,
+ Buffer(a_buffer), a_offset, a_ld,
+ Buffer(b_buffer), b_offset, b_ld);
+ */
+ return StatusCode::kSuccess;
+}
+template StatusCode Trsm<float>(const Layout, const Side, const Triangle,
+ const Transpose, const Diagonal,
+ const size_t, const size_t, const float,
+ const cl_mem, const size_t, const size_t,
+ cl_mem, const size_t, const size_t,
+ cl_command_queue*, cl_event*);
+template StatusCode Trsm<double>(const Layout, const Side, const Triangle,
+ const Transpose, const Diagonal,
+ const size_t, const size_t, const double,
+ const cl_mem, const size_t, const size_t,
+ cl_mem, const size_t, const size_t,
+ cl_command_queue*, cl_event*);
+template StatusCode Trsm<float2>(const Layout, const Side, const Triangle,
+ const Transpose, const Diagonal,
+ const size_t, const size_t, const float2,
+ const cl_mem, const size_t, const size_t,
+ cl_mem, const size_t, const size_t,
+ cl_command_queue*, cl_event*);
+template StatusCode Trsm<double2>(const Layout, const Side, const Triangle,
+ const Transpose, const Diagonal,
+ const size_t, const size_t, const double2,
+ const cl_mem, const size_t, const size_t,
+ cl_mem, const size_t, const size_t,
+ cl_command_queue*, cl_event*);
+
+// =================================================================================================
} // namespace clblast