summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt2
-rw-r--r--doc/clblast.md71
-rwxr-xr-xscripts/generator/generator.py4
-rw-r--r--src/clblast.cpp24
-rw-r--r--src/routines/level3/xtrsm.cpp72
-rw-r--r--src/routines/level3/xtrsm.hpp52
-rw-r--r--test/routines/level3/xtrsm.hpp159
7 files changed, 374 insertions, 10 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index edb03dbf..f6fa4045 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -158,7 +158,7 @@ endif()
set(LEVEL1_ROUTINES xswap xscal xcopy xaxpy xdot xdotu xdotc xnrm2 xasum xamax)
set(LEVEL2_ROUTINES xgemv xgbmv xhemv xhbmv xhpmv xsymv xsbmv xspmv xtrmv xtbmv xtpmv
xger xgeru xgerc xher xhpr xher2 xhpr2 xsyr xspr xsyr2 xspr2)
-set(LEVEL3_ROUTINES xgemm xsymm xhemm xsyrk xherk xsyr2k xher2k xtrmm)
+set(LEVEL3_ROUTINES xgemm xsymm xhemm xsyrk xherk xsyr2k xher2k xtrmm xtrsm)
set(LEVELX_ROUTINES xomatcopy)
set(ROUTINES ${LEVEL1_ROUTINES} ${LEVEL2_ROUTINES} ${LEVEL3_ROUTINES} ${LEVELX_ROUTINES})
set(PRECISIONS 32 64 3232 6464 16)
diff --git a/doc/clblast.md b/doc/clblast.md
index 37b99f3d..d7be0005 100644
--- a/doc/clblast.md
+++ b/doc/clblast.md
@@ -2708,6 +2708,77 @@ Requirements for TRMM:
+xTRSM: Solves a triangular system of equations
+-------------
+
+Solves the equation _A * X = alpha * B_ for the unknown _m_ by _n_ matrix X, in which _A_ is an _n_ by _n_ unit or non-unit triangular matrix and B is an _m_ by _n_ matrix. The matrix _B_ is overwritten by the solution _X_.
+
+C++ API:
+```
+template <typename T>
+StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const T alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event)
+```
+
+C API:
+```
+CLBlastStatusCode CLBlastStrsm(const CLBlastLayout layout, const CLBlastSide side, const CLBlastTriangle triangle, const CLBlastTranspose a_transpose, const CLBlastDiagonal diagonal,
+ const size_t m, const size_t n,
+ const float alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event)
+CLBlastStatusCode CLBlastDtrsm(const CLBlastLayout layout, const CLBlastSide side, const CLBlastTriangle triangle, const CLBlastTranspose a_transpose, const CLBlastDiagonal diagonal,
+ const size_t m, const size_t n,
+ const double alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event)
+CLBlastStatusCode CLBlastCtrsm(const CLBlastLayout layout, const CLBlastSide side, const CLBlastTriangle triangle, const CLBlastTranspose a_transpose, const CLBlastDiagonal diagonal,
+ const size_t m, const size_t n,
+ const cl_float2 alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event)
+CLBlastStatusCode CLBlastZtrsm(const CLBlastLayout layout, const CLBlastSide side, const CLBlastTriangle triangle, const CLBlastTranspose a_transpose, const CLBlastDiagonal diagonal,
+ const size_t m, const size_t n,
+ const cl_double2 alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event)
+CLBlastStatusCode CLBlastHtrsm(const CLBlastLayout layout, const CLBlastSide side, const CLBlastTriangle triangle, const CLBlastTranspose a_transpose, const CLBlastDiagonal diagonal,
+ const size_t m, const size_t n,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event)
+```
+
+Arguments to TRSM:
+
+* `const Layout layout`: Data-layout of the matrices, either `Layout::kRowMajor` (101) for row-major layout or `Layout::kColMajor` (102) for column-major data-layout.
+* `const Side side`: The position of the triangular matrix in the operation, either on the `Side::kLeft` (141) or `Side::kRight` (142).
+* `const Triangle triangle`: The part of the array of the triangular matrix to be used, either `Triangle::kUpper` (121) or `Triangle::kLower` (122).
+* `const Transpose a_transpose`: Transposing the input matrix A, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.
+* `const Diagonal diagonal`: The property of the diagonal matrix, either `Diagonal::kNonUnit` (131) for non-unit values on the diagonal or `Diagonal::kUnit` (132) for unit values on the diagonal.
+* `const size_t m`: Integer size argument. This value must be positive.
+* `const size_t n`: Integer size argument. This value must be positive.
+* `const T alpha`: Input scalar constant.
+* `const cl_mem a_buffer`: OpenCL buffer to store the input A matrix.
+* `const size_t a_offset`: The offset in elements from the start of the input A matrix.
+* `const size_t a_ld`: Leading dimension of the input A matrix. This value must be greater than 0.
+* `cl_mem b_buffer`: OpenCL buffer to store the output B matrix.
+* `const size_t b_offset`: The offset in elements from the start of the output B matrix.
+* `const size_t b_ld`: Leading dimension of the output B matrix. This value must be greater than 0.
+* `cl_command_queue* queue`: Pointer to an OpenCL command queue associated with a context and device to execute the routine on.
+* `cl_event* event`: Pointer to an OpenCL event to be able to wait for completion of the routine's OpenCL kernel(s). This is an optional argument.
+
+
+
xOMATCOPY: Scaling and out-place transpose/copy (non-BLAS function)
-------------
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 35d902b7..d71e392d 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -41,7 +41,7 @@ FILES = [
"/include/clblast_netlib_c.h",
"/src/clblast_netlib_c.cpp",
]
-HEADER_LINES = [117, 73, 118, 22, 29, 41, 65, 32]
+HEADER_LINES = [117, 74, 118, 22, 29, 41, 65, 32]
FOOTER_LINES = [17, 80, 19, 18, 6, 6, 9, 2]
# Different possibilities for requirements
@@ -154,7 +154,7 @@ ROUTINES = [
Routine(True, True, "3", "syr2k", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], [ankab,bnkab,cn],["alpha","beta"], "", "Rank-2K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * B^T + alpha * B * A^T + beta * C_ or _C = alpha * A^T * B + alpha * B^T * A + beta * C_, in which _A_ and _B_ are general matrices and _A^T_ and _B^T_ are their transposed versions, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
Routine(True, True, "3", "her2k", TU, [Ccs,Zzd], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], [ankab,bnkab,cn],["alpha","beta"], "", "Rank-2K update of a hermitian matrix", "Same operation as xSYR2K, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
Routine(True, True, "3", "trmm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Triangular matrix-matrix multiplication", "Performs the matrix product _B = alpha * A * B_ or _B = alpha * B * A_, in which _A_ is a unit or non-unit triangular matrix, _B_ (_m_ by _n_) is the general matrix to be updated, and _alpha_ is a scalar value.", [ald_side_m_n, bld_m]),
- Routine(False, True, "3", "trsm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Solves a triangular system of equations", "", []),
+ Routine(True, True, "3", "trsm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Solves a triangular system of equations", "Solves the equation _A * X = alpha * B_ for the unknown _m_ by _n_ matrix X, in which _A_ is an _n_ by _n_ unit or non-unit triangular matrix and B is an _m_ by _n_ matrix. The matrix _B_ is overwritten by the solution _X_.", []),
],
[ # Level X: extra routines (not part of BLAS)
Routine(True, True, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 4bb4e0b3..68671e50 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -66,6 +66,7 @@
#include "routines/level3/xsyr2k.hpp"
#include "routines/level3/xher2k.hpp"
#include "routines/level3/xtrmm.hpp"
+#include "routines/level3/xtrsm.hpp"
// Level-x includes (non-BLAS)
#include "routines/levelx/xomatcopy.hpp"
@@ -2067,13 +2068,22 @@ template StatusCode PUBLIC_API Trmm<half>(const Layout, const Side, const Triang
// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM
template <typename T>
-StatusCode Trsm(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
- const size_t, const size_t,
- const T,
- const cl_mem, const size_t, const size_t,
- cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*) {
- return StatusCode::kNotImplemented;
+StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const T alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event) {
+ try {
+ auto queue_cpp = Queue(*queue);
+ auto routine = Xtrsm<T>(queue_cpp, event);
+ routine.DoTrsm(layout, side, triangle, a_transpose, diagonal,
+ m, n,
+ alpha,
+ Buffer<T>(a_buffer), a_offset, a_ld,
+ Buffer<T>(b_buffer), b_offset, b_ld);
+ return StatusCode::kSuccess;
+ } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Trsm<float>(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
diff --git a/src/routines/level3/xtrsm.cpp b/src/routines/level3/xtrsm.cpp
new file mode 100644
index 00000000..0ac1a58e
--- /dev/null
+++ b/src/routines/level3/xtrsm.cpp
@@ -0,0 +1,72 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file implements the Xtrsm class (see the header for information about the class).
+//
+// =================================================================================================
+
+#include "routines/level3/xtrsm.hpp"
+
+#include <string>
+#include <vector>
+
+namespace clblast {
+// =================================================================================================
+
+// Constructor: forwards to base class constructor
+template <typename T>
+Xtrsm<T>::Xtrsm(Queue &queue, EventPointer event, const std::string &name):
+ Xgemm<T>(queue, event, name) {
+}
+
+// =================================================================================================
+
+// The main routine
+template <typename T>
+void Xtrsm<T>::DoTrsm(const Layout layout, const Side side, const Triangle triangle,
+ const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const T alpha,
+ const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
+ const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld) {
+
+ // Makes sure all dimensions are larger than zero
+ if ((m == 0) || (n == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
+
+ // Computes the k dimension. This is based on whether or not matrix is A (on the left)
+ // or B (on the right) in the Xgemm routine.
+ auto k = (side == Side::kLeft) ? m : n;
+
+ // Checks for validity of the triangular A matrix
+ TestMatrixA(k, k, a_buffer, a_offset, a_ld);
+
+ // Checks for validity of the input/output B matrix
+ const auto b_one = (layout == Layout::kRowMajor) ? n : m;
+ const auto b_two = (layout == Layout::kRowMajor) ? m : n;
+ TestMatrixB(b_one, b_two, b_buffer, b_offset, b_ld);
+
+ // Creates a copy of B to avoid overwriting input in GEMM while computing output
+ const auto b_size = (b_ld * (b_two - 1) + b_one + b_offset);
+ auto b_buffer_copy = Buffer<T>(context_, b_size);
+ b_buffer.CopyTo(queue_, b_size, b_buffer_copy);
+
+ // TODO: Implement TRSM computation
+}
+
+// =================================================================================================
+
+// Compiles the templated class
+template class Xtrsm<half>;
+template class Xtrsm<float>;
+template class Xtrsm<double>;
+template class Xtrsm<float2>;
+template class Xtrsm<double2>;
+
+// =================================================================================================
+} // namespace clblast
diff --git a/src/routines/level3/xtrsm.hpp b/src/routines/level3/xtrsm.hpp
new file mode 100644
index 00000000..288e9d11
--- /dev/null
+++ b/src/routines/level3/xtrsm.hpp
@@ -0,0 +1,52 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file implements the Xtrsm routine. The implementation is based on ??? (TODO).
+// Therefore, this class inherits from the Xgemm class.
+//
+// =================================================================================================
+
+#ifndef CLBLAST_ROUTINES_XTRSM_H_
+#define CLBLAST_ROUTINES_XTRSM_H_
+
+#include "routines/level3/xgemm.hpp"
+
+namespace clblast {
+// =================================================================================================
+
+// See comment at top of file for a description of the class
+template <typename T>
+class Xtrsm: public Xgemm<T> {
+ public:
+
+ // Uses methods and variables the Xgemm routine
+ using Xgemm<T>::routine_name_;
+ using Xgemm<T>::queue_;
+ using Xgemm<T>::context_;
+ using Xgemm<T>::device_;
+ using Xgemm<T>::db_;
+ using Xgemm<T>::DoGemm;
+
+ // Constructor
+ Xtrsm(Queue &queue, EventPointer event, const std::string &name = "TRSM");
+
+ // Templated-precision implementation of the routine
+ void DoTrsm(const Layout layout, const Side side, const Triangle triangle,
+ const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const T alpha,
+ const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
+ const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld);
+};
+
+// =================================================================================================
+} // namespace clblast
+
+// CLBLAST_ROUTINES_XTRSM_H_
+#endif
diff --git a/test/routines/level3/xtrsm.hpp b/test/routines/level3/xtrsm.hpp
new file mode 100644
index 00000000..2480a817
--- /dev/null
+++ b/test/routines/level3/xtrsm.hpp
@@ -0,0 +1,159 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file implements a class with static methods to describe the Xtrsm routine. Examples of
+// such 'descriptions' are how to calculate the size a of buffer or how to run the routine. These
+// static methods are used by the correctness tester and the performance tester.
+//
+// =================================================================================================
+
+#ifndef CLBLAST_TEST_ROUTINES_XTRSM_H_
+#define CLBLAST_TEST_ROUTINES_XTRSM_H_
+
+#include <vector>
+#include <string>
+
+#ifdef CLBLAST_REF_CLBLAS
+ #include "test/wrapper_clblas.hpp"
+#endif
+#ifdef CLBLAST_REF_CBLAS
+ #include "test/wrapper_cblas.hpp"
+#endif
+
+namespace clblast {
+// =================================================================================================
+
+// See comment at top of file for a description of the class
+template <typename T>
+class TestXtrsm {
+ public:
+
+ // The BLAS level: 1, 2, or 3
+ static size_t BLASLevel() { return 3; }
+
+ // The list of arguments relevant for this routine
+ static std::vector<std::string> GetOptions() {
+ return {kArgM, kArgN,
+ kArgLayout, kArgSide, kArgTriangle, kArgATransp, kArgDiagonal,
+ kArgALeadDim, kArgBLeadDim,
+ kArgAOffset, kArgBOffset,
+ kArgAlpha};
+ }
+
+ // Describes how to obtain the sizes of the buffers
+ static size_t GetSizeA(const Arguments<T> &args) {
+ auto k = (args.side == Side::kLeft) ? args.m : args.n;
+ return k * args.a_ld + args.a_offset;
+ }
+ static size_t GetSizeB(const Arguments<T> &args) {
+ auto b_rotated = (args.layout == Layout::kRowMajor);
+ auto b_two = (b_rotated) ? args.m : args.n;
+ return b_two * args.b_ld + args.b_offset;
+ }
+
+ // Describes how to set the sizes of all the buffers
+ static void SetSizes(Arguments<T> &args) {
+ args.a_size = GetSizeA(args);
+ args.b_size = GetSizeB(args);
+ }
+
+ // Describes what the default values of the leading dimensions of the matrices are
+ static size_t DefaultLDA(const Arguments<T> &args) { return args.m; }
+ static size_t DefaultLDB(const Arguments<T> &args) { return args.n; }
+ static size_t DefaultLDC(const Arguments<T> &) { return 1; } // N/A for this routine
+
+ // Describes which transpose options are relevant for this routine
+ using Transposes = std::vector<Transpose>;
+ static Transposes GetATransposes(const Transposes &all) { return all; }
+ static Transposes GetBTransposes(const Transposes &) { return {}; } // N/A for this routine
+
+ // Describes how to run the CLBlast routine
+ static StatusCode RunRoutine(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
+ auto queue_plain = queue();
+ auto event = cl_event{};
+ auto status = Trsm(args.layout, args.side, args.triangle, args.a_transpose, args.diagonal,
+ args.m, args.n, args.alpha,
+ buffers.a_mat(), args.a_offset, args.a_ld,
+ buffers.b_mat(), args.b_offset, args.b_ld,
+ &queue_plain, &event);
+ if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); }
+ return status;
+ }
+
+ // Describes how to run the clBLAS routine (for correctness/performance comparison)
+ #ifdef CLBLAST_REF_CLBLAS
+ static StatusCode RunReference1(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
+ auto queue_plain = queue();
+ auto event = cl_event{};
+ auto status = clblasXtrsm(convertToCLBLAS(args.layout),
+ convertToCLBLAS(args.side),
+ convertToCLBLAS(args.triangle),
+ convertToCLBLAS(args.a_transpose),
+ convertToCLBLAS(args.diagonal),
+ args.m, args.n, args.alpha,
+ buffers.a_mat, args.a_offset, args.a_ld,
+ buffers.b_mat, args.b_offset, args.b_ld,
+ 1, &queue_plain, 0, nullptr, &event);
+ clWaitForEvents(1, &event);
+ return static_cast<StatusCode>(status);
+ }
+ #endif
+
+ // Describes how to run the CPU BLAS routine (for correctness/performance comparison)
+ #ifdef CLBLAST_REF_CBLAS
+ static StatusCode RunReference2(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
+ std::vector<T> a_mat_cpu(args.a_size, static_cast<T>(0));
+ std::vector<T> b_mat_cpu(args.b_size, static_cast<T>(0));
+ buffers.a_mat.Read(queue, args.a_size, a_mat_cpu);
+ buffers.b_mat.Read(queue, args.b_size, b_mat_cpu);
+ cblasXtrsm(convertToCBLAS(args.layout),
+ convertToCBLAS(args.side),
+ convertToCBLAS(args.triangle),
+ convertToCBLAS(args.a_transpose),
+ convertToCBLAS(args.diagonal),
+ args.m, args.n, args.alpha,
+ a_mat_cpu, args.a_offset, args.a_ld,
+ b_mat_cpu, args.b_offset, args.b_ld);
+ buffers.b_mat.Write(queue, args.b_size, b_mat_cpu);
+ return StatusCode::kSuccess;
+ }
+ #endif
+
+ // Describes how to download the results of the computation (more importantly: which buffer)
+ static std::vector<T> DownloadResult(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
+ std::vector<T> result(args.b_size, static_cast<T>(0));
+ buffers.b_mat.Read(queue, args.b_size, result);
+ return result;
+ }
+
+ // Describes how to compute the indices of the result buffer
+ static size_t ResultID1(const Arguments<T> &args) { return args.m; }
+ static size_t ResultID2(const Arguments<T> &args) { return args.n; }
+ static size_t GetResultIndex(const Arguments<T> &args, const size_t id1, const size_t id2) {
+ return (args.layout == Layout::kRowMajor) ?
+ id1*args.b_ld + id2 + args.b_offset:
+ id2*args.b_ld + id1 + args.b_offset;
+ }
+
+ // Describes how to compute performance metrics
+ static size_t GetFlops(const Arguments<T> &args) {
+ auto k = (args.side == Side::kLeft) ? args.m : args.n;
+ return args.m * args.n * k;
+ }
+ static size_t GetBytes(const Arguments<T> &args) {
+ auto k = (args.side == Side::kLeft) ? args.m : args.n;
+ return (k*k + 2*args.m*args.n) * sizeof(T);
+ }
+};
+
+// =================================================================================================
+} // namespace clblast
+
+// CLBLAST_TEST_ROUTINES_XTRSM_H_
+#endif