summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt2
-rw-r--r--doc/clblast.md99
-rw-r--r--include/clblast.h12
-rw-r--r--include/clblast_c.h47
-rwxr-xr-xscripts/generator/generator.py3
-rw-r--r--src/clblast.cpp84
-rw-r--r--src/clblast_c.cpp157
-rw-r--r--src/routines/levelx/xgemmbatched.cpp115
-rw-r--r--src/routines/levelx/xgemmbatched.hpp47
-rw-r--r--test/correctness/routines/levelx/xgemmbatched.cpp30
-rw-r--r--test/performance/routines/levelx/xgemmbatched.cpp37
-rw-r--r--test/routines/levelx/xgemmbatched.hpp207
12 files changed, 838 insertions, 2 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ef6156dd..62cf00cc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -159,7 +159,7 @@ set(LEVEL1_ROUTINES xswap xscal xcopy xaxpy xdot xdotu xdotc xnrm2 xasum xamax)
set(LEVEL2_ROUTINES xgemv xgbmv xhemv xhbmv xhpmv xsymv xsbmv xspmv xtrmv xtbmv xtpmv xtrsv
xger xgeru xgerc xher xhpr xher2 xhpr2 xsyr xspr xsyr2 xspr2)
set(LEVEL3_ROUTINES xgemm xsymm xhemm xsyrk xherk xsyr2k xher2k xtrmm xtrsm)
-set(LEVELX_ROUTINES xomatcopy xaxpybatched)
+set(LEVELX_ROUTINES xomatcopy xaxpybatched xgemmbatched)
set(ROUTINES ${LEVEL1_ROUTINES} ${LEVEL2_ROUTINES} ${LEVEL3_ROUTINES} ${LEVELX_ROUTINES})
set(PRECISIONS 32 64 3232 6464 16)
diff --git a/doc/clblast.md b/doc/clblast.md
index 120c0c2c..6ff5f7d0 100644
--- a/doc/clblast.md
+++ b/doc/clblast.md
@@ -2969,6 +2969,105 @@ Arguments to AXPYBATCHED:
+xGEMMBATCHED: Batched version of GEMM
+-------------
+
+As GEMM, but multiple operations are batched together for better performance.
+
+C++ API:
+```
+template <typename T>
+StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const T *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const T *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event)
+```
+
+C API:
+```
+CLBlastStatusCode CLBlastSgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const float *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const float *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event)
+CLBlastStatusCode CLBlastDgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const double *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const double *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event)
+CLBlastStatusCode CLBlastCgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_float2 *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const cl_float2 *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event)
+CLBlastStatusCode CLBlastZgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_double2 *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const cl_double2 *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event)
+CLBlastStatusCode CLBlastHgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_half *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const cl_half *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event)
+```
+
+Arguments to GEMMBATCHED:
+
+* `const Layout layout`: Data-layout of the matrices, either `Layout::kRowMajor` (101) for row-major layout or `Layout::kColMajor` (102) for column-major data-layout.
+* `const Transpose a_transpose`: Transposing the input matrix A, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.
+* `const Transpose b_transpose`: Transposing the input matrix B, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.
+* `const size_t m`: Integer size argument. This value must be positive.
+* `const size_t n`: Integer size argument. This value must be positive.
+* `const size_t k`: Integer size argument. This value must be positive.
+* `const T *alphas`: Input scalar constants.
+* `const cl_mem a_buffer`: OpenCL buffer to store the input A matrix.
+* `const size_t *a_offsets`: The offsets in elements from the start of the input A matrix.
+* `const size_t a_ld`: Leading dimension of the input A matrix. This value must be greater than 0.
+* `const cl_mem b_buffer`: OpenCL buffer to store the input B matrix.
+* `const size_t *b_offsets`: The offsets in elements from the start of the input B matrix.
+* `const size_t b_ld`: Leading dimension of the input B matrix. This value must be greater than 0.
+* `const T *betas`: Input scalar constants.
+* `cl_mem c_buffer`: OpenCL buffer to store the output C matrix.
+* `const size_t *c_offsets`: The offsets in elements from the start of the output C matrix.
+* `const size_t c_ld`: Leading dimension of the output C matrix. This value must be greater than 0.
+* `const size_t batch_count`: Number of batches. This value must be positive.
+* `cl_command_queue* queue`: Pointer to an OpenCL command queue associated with a context and device to execute the routine on.
+* `cl_event* event`: Pointer to an OpenCL event to be able to wait for completion of the routine's OpenCL kernel(s). This is an optional argument.
+
+Requirements for GEMMBATCHED:
+
+* When `transpose_a == Transpose::kNo`, then `a_ld` must be at least `m`, otherwise `a_ld` must be at least `k`.
+* When `transpose_b == Transpose::kNo`, then `b_ld` must be at least `k`, otherwise `b_ld` must be at least `n`.
+* The value of `c_ld` must be at least `m`.
+
+
+
ClearCache: Resets the cache of compiled binaries (auxiliary function)
-------------
diff --git a/include/clblast.h b/include/clblast.h
index a1f14471..2520d601 100644
--- a/include/clblast.h
+++ b/include/clblast.h
@@ -619,6 +619,18 @@ StatusCode AxpyBatched(const size_t n,
const size_t batch_count,
cl_command_queue* queue, cl_event* event = nullptr);
+// Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED
+template <typename T>
+StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const T *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const T *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event = nullptr);
+
// =================================================================================================
// CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on
diff --git a/include/clblast_c.h b/include/clblast_c.h
index 4f21ba17..b0ef5f34 100644
--- a/include/clblast_c.h
+++ b/include/clblast_c.h
@@ -1360,6 +1360,53 @@ CLBlastStatusCode PUBLIC_API CLBlastHaxpyBatched(const size_t n,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);
+// Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED
+CLBlastStatusCode PUBLIC_API CLBlastSgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const float *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const float *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event);
+CLBlastStatusCode PUBLIC_API CLBlastDgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const double *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const double *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event);
+CLBlastStatusCode PUBLIC_API CLBlastCgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_float2 *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const cl_float2 *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event);
+CLBlastStatusCode PUBLIC_API CLBlastZgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_double2 *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const cl_double2 *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event);
+CLBlastStatusCode PUBLIC_API CLBlastHgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_half *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const cl_half *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event);
+
// =================================================================================================
// CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 8dd5fc0c..086b27d3 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -41,7 +41,7 @@ FILES = [
"/include/clblast_netlib_c.h",
"/src/clblast_netlib_c.cpp",
]
-HEADER_LINES = [122, 76, 126, 23, 29, 41, 65, 32]
+HEADER_LINES = [123, 76, 126, 23, 29, 41, 65, 32]
FOOTER_LINES = [25, 138, 27, 38, 6, 6, 9, 2]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 63
@@ -163,6 +163,7 @@ ROUTINES = [
Routine(True, True, False, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
# Batched routines:
Routine(True, True, True, "x", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Batched version of AXPY", "As AXPY, but multiple operations are batched together for better performance.", []),
+ Routine(True, True, True, "x", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "Batched version of GEMM", "As GEMM, but multiple operations are batched together for better performance.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
]]
diff --git a/src/clblast.cpp b/src/clblast.cpp
index d3db8edf..a8bcf91d 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -72,6 +72,7 @@
// Level-x includes (non-BLAS)
#include "routines/levelx/xomatcopy.hpp"
#include "routines/levelx/xaxpybatched.hpp"
+#include "routines/levelx/xgemmbatched.hpp"
namespace clblast {
@@ -2231,6 +2232,89 @@ template StatusCode PUBLIC_API AxpyBatched<half>(const size_t,
cl_mem, const size_t*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
+
+// Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED
+template <typename T>
+StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const T *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const T *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event) {
+ try {
+ auto queue_cpp = Queue(*queue);
+ auto routine = XgemmBatched<T>(queue_cpp, event);
+ auto alphas_cpp = std::vector<T>();
+ auto betas_cpp = std::vector<T>();
+ auto a_offsets_cpp = std::vector<size_t>();
+ auto b_offsets_cpp = std::vector<size_t>();
+ auto c_offsets_cpp = std::vector<size_t>();
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ alphas_cpp.push_back(alphas[batch]);
+ betas_cpp.push_back(betas[batch]);
+ a_offsets_cpp.push_back(a_offsets[batch]);
+ b_offsets_cpp.push_back(b_offsets[batch]);
+ c_offsets_cpp.push_back(c_offsets[batch]);
+ }
+ routine.DoGemmBatched(layout, a_transpose, b_transpose,
+ m, n, k,
+ alphas_cpp,
+ Buffer<T>(a_buffer), a_offsets_cpp, a_ld,
+ Buffer<T>(b_buffer), b_offsets_cpp, b_ld,
+ betas_cpp,
+ Buffer<T>(c_buffer), c_offsets_cpp, c_ld,
+ batch_count);
+ return StatusCode::kSuccess;
+ } catch (...) { return DispatchException(); }
+}
+template StatusCode PUBLIC_API GemmBatched<float>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const float*,
+ const cl_mem, const size_t*, const size_t,
+ const cl_mem, const size_t*, const size_t,
+ const float*,
+ cl_mem, const size_t*, const size_t,
+ const size_t,
+ cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API GemmBatched<double>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const double*,
+ const cl_mem, const size_t*, const size_t,
+ const cl_mem, const size_t*, const size_t,
+ const double*,
+ cl_mem, const size_t*, const size_t,
+ const size_t,
+ cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API GemmBatched<float2>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const float2*,
+ const cl_mem, const size_t*, const size_t,
+ const cl_mem, const size_t*, const size_t,
+ const float2*,
+ cl_mem, const size_t*, const size_t,
+ const size_t,
+ cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API GemmBatched<double2>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const double2*,
+ const cl_mem, const size_t*, const size_t,
+ const cl_mem, const size_t*, const size_t,
+ const double2*,
+ cl_mem, const size_t*, const size_t,
+ const size_t,
+ cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const half*,
+ const cl_mem, const size_t*, const size_t,
+ const cl_mem, const size_t*, const size_t,
+ const half*,
+ cl_mem, const size_t*, const size_t,
+ const size_t,
+ cl_command_queue*, cl_event*);
// =================================================================================================
// Clears the cache of stored binaries
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp
index b09f8c54..f5104bad 100644
--- a/src/clblast_c.cpp
+++ b/src/clblast_c.cpp
@@ -3554,6 +3554,163 @@ CLBlastStatusCode CLBlastHaxpyBatched(const size_t n,
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
+// GEMM
+CLBlastStatusCode CLBlastSgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const float *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const float *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event) {
+ auto alphas_cpp = std::vector<float>();
+ auto betas_cpp = std::vector<float>();
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ alphas_cpp.push_back(alphas[batch]);
+ betas_cpp.push_back(betas[batch]);
+ }
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::GemmBatched(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ alphas_cpp.data(),
+ a_buffer, a_offsets, a_ld,
+ b_buffer, b_offsets, b_ld,
+ betas_cpp.data(),
+ c_buffer, c_offsets, c_ld,
+ batch_count,
+ queue, event)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastDgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const double *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const double *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event) {
+ auto alphas_cpp = std::vector<double>();
+ auto betas_cpp = std::vector<double>();
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ alphas_cpp.push_back(alphas[batch]);
+ betas_cpp.push_back(betas[batch]);
+ }
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::GemmBatched(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ alphas_cpp.data(),
+ a_buffer, a_offsets, a_ld,
+ b_buffer, b_offsets, b_ld,
+ betas_cpp.data(),
+ c_buffer, c_offsets, c_ld,
+ batch_count,
+ queue, event)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastCgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_float2 *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const cl_float2 *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event) {
+ auto alphas_cpp = std::vector<float2>();
+ auto betas_cpp = std::vector<float2>();
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ alphas_cpp.push_back(float2{alphas[batch].s[0], alphas[batch].s[1]});
+ betas_cpp.push_back(float2{betas[batch].s[0], betas[batch].s[1]});
+ }
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::GemmBatched(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ alphas_cpp.data(),
+ a_buffer, a_offsets, a_ld,
+ b_buffer, b_offsets, b_ld,
+ betas_cpp.data(),
+ c_buffer, c_offsets, c_ld,
+ batch_count,
+ queue, event)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastZgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_double2 *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const cl_double2 *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event) {
+ auto alphas_cpp = std::vector<double2>();
+ auto betas_cpp = std::vector<double2>();
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ alphas_cpp.push_back(double2{alphas[batch].s[0], alphas[batch].s[1]});
+ betas_cpp.push_back(double2{betas[batch].s[0], betas[batch].s[1]});
+ }
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::GemmBatched(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ alphas_cpp.data(),
+ a_buffer, a_offsets, a_ld,
+ b_buffer, b_offsets, b_ld,
+ betas_cpp.data(),
+ c_buffer, c_offsets, c_ld,
+ batch_count,
+ queue, event)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+CLBlastStatusCode CLBlastHgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_half *alphas,
+ const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+ const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+ const cl_half *betas,
+ cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+ const size_t batch_count,
+ cl_command_queue* queue, cl_event* event) {
+ auto alphas_cpp = std::vector<half>();
+ auto betas_cpp = std::vector<half>();
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ alphas_cpp.push_back(alphas[batch]);
+ betas_cpp.push_back(betas[batch]);
+ }
+ try {
+ return static_cast<CLBlastStatusCode>(
+ clblast::GemmBatched(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ alphas_cpp.data(),
+ a_buffer, a_offsets, a_ld,
+ b_buffer, b_offsets, b_ld,
+ betas_cpp.data(),
+ c_buffer, c_offsets, c_ld,
+ batch_count,
+ queue, event)
+ );
+ } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
+}
+
// =================================================================================================
// Clears the cache of stored binaries
diff --git a/src/routines/levelx/xgemmbatched.cpp b/src/routines/levelx/xgemmbatched.cpp
new file mode 100644
index 00000000..b07425d5
--- /dev/null
+++ b/src/routines/levelx/xgemmbatched.cpp
@@ -0,0 +1,115 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file implements the XgemmBatched class (see the header for information about the class).
+//
+// =================================================================================================
+
+#include "routines/levelx/xgemmbatched.hpp"
+
+#include <string>
+#include <vector>
+
+namespace clblast {
+// =================================================================================================
+
+// Constructor: forwards to base class constructor
+template <typename T>
+XgemmBatched<T>::XgemmBatched(Queue &queue, EventPointer event, const std::string &name):
+ Routine(queue, event, name,
+ {"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","KernelSelection"},
+ PrecisionValue<T>(), {}, {
+ #include "../../kernels/level3/level3.opencl"
+ #include "../../kernels/level3/copy_fast.opencl"
+ #include "../../kernels/level3/copy_pad.opencl"
+ #include "../../kernels/level3/transpose_fast.opencl"
+ #include "../../kernels/level3/transpose_pad.opencl"
+ #include "../../kernels/level3/convert_symmetric.opencl"
+ #include "../../kernels/level3/convert_triangular.opencl"
+ #include "../../kernels/level3/convert_hermitian.opencl"
+ , // separated in multiple parts to prevent C1091 in MSVC 2013
+ #include "../../kernels/level3/xgemm_direct_part1.opencl"
+ #include "../../kernels/level3/xgemm_direct_part2.opencl"
+ #include "../../kernels/level3/xgemm_direct_part3.opencl"
+ , // separated in multiple parts to prevent C1091 in MSVC 2013
+ #include "../../kernels/level3/xgemm_part1.opencl"
+ #include "../../kernels/level3/xgemm_part2.opencl"
+ #include "../../kernels/level3/xgemm_part3.opencl"
+ }) {
+}
+
+// =================================================================================================
+
+// The main routine
+template <typename T>
+void XgemmBatched<T>::DoGemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const std::vector<T> &alphas,
+ const Buffer<T> & a_buffer, const std::vector<size_t> &a_offsets, const size_t a_ld,
+ const Buffer<T> & b_buffer, const std::vector<size_t> &b_offsets, const size_t b_ld,
+ const std::vector<T> &betas,
+ const Buffer<T> & c_buffer, const std::vector<size_t> &c_offsets, const size_t c_ld,
+ const size_t batch_count) {
+
+ // Tests for a valid batch count
+ if ((batch_count < 1) || (alphas.size() != batch_count) || (betas.size() != batch_count) ||
+ (a_offsets.size() != batch_count) || (b_offsets.size() != batch_count) || (c_offsets.size() != batch_count)) {
+ throw BLASError(StatusCode::kInvalidBatchCount);
+ }
+
+ // Makes sure all dimensions are larger than zero
+ if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
+
+ // Computes whether or not the matrices are transposed in memory. See GEMM routine for details.
+ const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) ||
+ (layout == Layout::kRowMajor && a_transpose == Transpose::kNo);
+ const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) ||
+ (layout == Layout::kRowMajor && b_transpose == Transpose::kNo);
+ const auto c_rotated = (layout == Layout::kRowMajor);
+ static const auto a_want_rotated = false;
+ static const auto b_want_rotated = true;
+ static const auto c_want_rotated = false;
+ const auto a_do_transpose = a_rotated != a_want_rotated;
+ const auto b_do_transpose = b_rotated != b_want_rotated;
+ const auto c_do_transpose = c_rotated != c_want_rotated;
+
+ // In case of complex data-types, the transpose can also become a conjugate transpose
+ const auto a_conjugate = (a_transpose == Transpose::kConjugate);
+ const auto b_conjugate = (b_transpose == Transpose::kConjugate);
+
+ // Computes the first and second dimensions of the 3 matrices taking into account whether the
+ // matrices are rotated or not
+ const auto a_one = (a_rotated) ? k : m;
+ const auto a_two = (a_rotated) ? m : k;
+ const auto b_one = (b_rotated) ? n : k;
+ const auto b_two = (b_rotated) ? k : n;
+ const auto c_one = (c_rotated) ? n : m;
+ const auto c_two = (c_rotated) ? m : n;
+
+ // Tests the matrices for validity
+ for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+ TestMatrixA(a_one, a_two, a_buffer, a_offsets[batch], a_ld);
+ TestMatrixB(b_one, b_two, b_buffer, b_offsets[batch], b_ld);
+ TestMatrixC(c_one, c_two, c_buffer, c_offsets[batch], c_ld);
+ }
+
+ // StatusCode::kNotImplemented;
+}
+
+// =================================================================================================
+
+// Compiles the templated class
+template class XgemmBatched<half>;
+template class XgemmBatched<float>;
+template class XgemmBatched<double>;
+template class XgemmBatched<float2>;
+template class XgemmBatched<double2>;
+
+// =================================================================================================
+} // namespace clblast
diff --git a/src/routines/levelx/xgemmbatched.hpp b/src/routines/levelx/xgemmbatched.hpp
new file mode 100644
index 00000000..710011d8
--- /dev/null
+++ b/src/routines/levelx/xgemmbatched.hpp
@@ -0,0 +1,47 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file implements the XgemmBatched routine. This is a non-blas batched version of GEMM.
+//
+// =================================================================================================
+
+#ifndef CLBLAST_ROUTINES_XGEMMBATCHED_H_
+#define CLBLAST_ROUTINES_XGEMMBATCHED_H_
+
+#include <vector>
+
+#include "routine.hpp"
+
+namespace clblast {
+// =================================================================================================
+
+// See comment at top of file for a description of the class
+template <typename T>
+class XgemmBatched: public Routine {
+ public:
+
+ // Constructor
+ XgemmBatched(Queue &queue, EventPointer event, const std::string &name = "GEMMBATCHED");
+
+ // Templated-precision implementation of the routine
+ void DoGemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const std::vector<T> &alphas,
+ const Buffer<T> & a_buffer, const std::vector<size_t> &a_offsets, const size_t a_ld,
+ const Buffer<T> & b_buffer, const std::vector<size_t> &b_offsets, const size_t b_ld,
+ const std::vector<T> &betas,
+ const Buffer<T> & c_buffer, const std::vector<size_t> &c_offsets, const size_t c_ld,
+ const size_t batch_count);
+};
+
+// =================================================================================================
+} // namespace clblast
+
+// CLBLAST_ROUTINES_XGEMMBATCHED_H_
+#endif
diff --git a/test/correctness/routines/levelx/xgemmbatched.cpp b/test/correctness/routines/levelx/xgemmbatched.cpp
new file mode 100644
index 00000000..748e1bb7
--- /dev/null
+++ b/test/correctness/routines/levelx/xgemmbatched.cpp
@@ -0,0 +1,30 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// =================================================================================================
+
+#include "test/correctness/testblas.hpp"
+#include "test/routines/levelx/xgemmbatched.hpp"
+
+// Shortcuts to the clblast namespace
+using float2 = clblast::float2;
+using double2 = clblast::double2;
+
+// Main function (not within the clblast namespace)
+int main(int argc, char *argv[]) {
+ auto errors = size_t{0};
+ errors += clblast::RunTests<clblast::TestXgemmBatched<float>, float, float>(argc, argv, false, "SGEMMBATCHED");
+ errors += clblast::RunTests<clblast::TestXgemmBatched<double>, double, double>(argc, argv, true, "DGEMMBATCHED");
+ errors += clblast::RunTests<clblast::TestXgemmBatched<float2>, float2, float2>(argc, argv, true, "CGEMMBATCHED");
+ errors += clblast::RunTests<clblast::TestXgemmBatched<double2>, double2, double2>(argc, argv, true, "ZGEMMBATCHED");
+ errors += clblast::RunTests<clblast::TestXgemmBatched<half>, half, half>(argc, argv, true, "HGEMMBATCHED");
+ if (errors > 0) { return 1; } else { return 0; }
+}
+
+// =================================================================================================
diff --git a/test/performance/routines/levelx/xgemmbatched.cpp b/test/performance/routines/levelx/xgemmbatched.cpp
new file mode 100644
index 00000000..c9477fad
--- /dev/null
+++ b/test/performance/routines/levelx/xgemmbatched.cpp
@@ -0,0 +1,37 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// =================================================================================================
+
+#include "test/performance/client.hpp"
+#include "test/routines/levelx/xgemmbatched.hpp"
+
+// Shortcuts to the clblast namespace
+using float2 = clblast::float2;
+using double2 = clblast::double2;
+
+// Main function (not within the clblast namespace)
+int main(int argc, char *argv[]) {
+ const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
+ switch(clblast::GetPrecision(command_line_args, clblast::Precision::kSingle)) {
+ case clblast::Precision::kHalf:
+ clblast::RunClient<clblast::TestXgemmBatched<half>, half, half>(argc, argv); break;
+ case clblast::Precision::kSingle:
+ clblast::RunClient<clblast::TestXgemmBatched<float>, float, float>(argc, argv); break;
+ case clblast::Precision::kDouble:
+ clblast::RunClient<clblast::TestXgemmBatched<double>, double, double>(argc, argv); break;
+ case clblast::Precision::kComplexSingle:
+ clblast::RunClient<clblast::TestXgemmBatched<float2>, float2, float2>(argc, argv); break;
+ case clblast::Precision::kComplexDouble:
+ clblast::RunClient<clblast::TestXgemmBatched<double2>, double2, double2>(argc, argv); break;
+ }
+ return 0;
+}
+
+// =================================================================================================
diff --git a/test/routines/levelx/xgemmbatched.hpp b/test/routines/levelx/xgemmbatched.hpp
new file mode 100644
index 00000000..80a30e4d
--- /dev/null
+++ b/test/routines/levelx/xgemmbatched.hpp
@@ -0,0 +1,207 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file implements a class with static methods to describe the XgemmBatched routine. Examples of
+// such 'descriptions' are how to calculate the size a of buffer or how to run the routine. These
+// static methods are used by the correctness tester and the performance tester.
+//
+// =================================================================================================
+
+#ifndef CLBLAST_TEST_ROUTINES_XGEMMBATCHED_H_
+#define CLBLAST_TEST_ROUTINES_XGEMMBATCHED_H_
+
+#include <vector>
+#include <string>
+
+#ifdef CLBLAST_REF_CLBLAS
+ #include "test/wrapper_clblas.hpp"
+#endif
+#ifdef CLBLAST_REF_CBLAS
+ #include "test/wrapper_cblas.hpp"
+#endif
+
+namespace clblast {
+// =================================================================================================
+
+// See comment at top of file for a description of the class
+template <typename T>
+class TestXgemmBatched {
+ public:
+
+ // Although it is a non-BLAS routine, it can still be tested against level-3 routines in a loop
+ static size_t BLASLevel() { return 3; }
+
+ // The list of arguments relevant for this routine
+ static std::vector<std::string> GetOptions() {
+ return {kArgM, kArgN, kArgK,
+ kArgLayout, kArgATransp, kArgBTransp,
+ kArgALeadDim, kArgBLeadDim, kArgCLeadDim,
+ kArgAOffset, kArgBOffset, kArgCOffset,
+ kArgBatchCount, kArgAlpha, kArgBeta};
+ }
+
+ // Helper for the sizes per batch
+ static size_t PerBatchSizeA(const Arguments<T> &args) {
+ auto a_rotated = (args.layout == Layout::kColMajor && args.a_transpose != Transpose::kNo) ||
+ (args.layout == Layout::kRowMajor && args.a_transpose == Transpose::kNo);
+ auto a_two = (a_rotated) ? args.m : args.k;
+ return a_two * args.a_ld;
+ }
+ static size_t PerBatchSizeB(const Arguments<T> &args) {
+ auto b_rotated = (args.layout == Layout::kColMajor && args.b_transpose != Transpose::kNo) ||
+ (args.layout == Layout::kRowMajor && args.b_transpose == Transpose::kNo);
+ auto b_two = (b_rotated) ? args.k : args.n;
+ return b_two * args.b_ld;
+ }
+ static size_t PerBatchSizeC(const Arguments<T> &args) {
+ auto c_rotated = (args.layout == Layout::kRowMajor);
+ auto c_two = (c_rotated) ? args.m : args.n;
+ return c_two * args.c_ld;
+ }
+
+ // Describes how to obtain the sizes of the buffers
+ static size_t GetSizeA(const Arguments<T> &args) {
+ return PerBatchSizeA(args) * args.batch_count + args.a_offset;
+ }
+ static size_t GetSizeB(const Arguments<T> &args) {
+ return PerBatchSizeB(args) * args.batch_count + args.b_offset;
+ }
+ static size_t GetSizeC(const Arguments<T> &args) {
+ return PerBatchSizeC(args) * args.batch_count + args.c_offset;
+ }
+
+ // Describes how to set the sizes of all the buffers
+ static void SetSizes(Arguments<T> &args) {
+ args.a_size = GetSizeA(args);
+ args.b_size = GetSizeB(args);
+ args.c_size = GetSizeC(args);
+
+ // Also sets the batch-related variables
+ args.a_offsets = std::vector<size_t>(args.batch_count);
+ args.b_offsets = std::vector<size_t>(args.batch_count);
+ args.c_offsets = std::vector<size_t>(args.batch_count);
+ args.alphas = std::vector<T>(args.batch_count);
+ args.betas = std::vector<T>(args.batch_count);
+ for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
+ args.a_offsets[batch] = batch * PerBatchSizeA(args) + args.a_offset;
+ args.b_offsets[batch] = batch * PerBatchSizeB(args) + args.b_offset;
+ args.c_offsets[batch] = batch * PerBatchSizeC(args) + args.c_offset;
+ args.alphas[batch] = args.alpha + Constant<T>(batch);
+ args.betas[batch] = args.beta + Constant<T>(batch);
+ }
+ }
+
+ // Describes what the default values of the leading dimensions of the matrices are
+ static size_t DefaultLDA(const Arguments<T> &args) { return args.k; }
+ static size_t DefaultLDB(const Arguments<T> &args) { return args.n; }
+ static size_t DefaultLDC(const Arguments<T> &args) { return args.n; }
+
+ // Describes which transpose options are relevant for this routine
+ using Transposes = std::vector<Transpose>;
+ static Transposes GetATransposes(const Transposes &all) { return all; }
+ static Transposes GetBTransposes(const Transposes &all) { return all; }
+
+ // Describes how to prepare the input data
+ static void PrepareData(const Arguments<T>&, Queue&, const int, std::vector<T>&,
+ std::vector<T>&, std::vector<T>&, std::vector<T>&, std::vector<T>&,
+ std::vector<T>&, std::vector<T>&) {} // N/A for this routine
+
+ // Describes how to run the CLBlast routine
+ static StatusCode RunRoutine(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
+ auto queue_plain = queue();
+ auto event = cl_event{};
+ auto status = GemmBatched(args.layout, args.a_transpose, args.b_transpose,
+ args.m, args.n, args.k, args.alphas.data(),
+ buffers.a_mat(), args.a_offsets.data(), args.a_ld,
+ buffers.b_mat(), args.b_offsets.data(), args.b_ld, args.betas.data(),
+ buffers.c_mat(), args.c_offsets.data(), args.c_ld,
+ args.batch_count,
+ &queue_plain, &event);
+ if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); }
+ return status;
+ }
+
+ // Describes how to run the clBLAS routine (for correctness/performance comparison)
+ #ifdef CLBLAST_REF_CLBLAS
+ static StatusCode RunReference1(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
+ auto queue_plain = queue();
+ for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
+ auto event = cl_event{};
+ auto status = clblasXgemm(convertToCLBLAS(args.layout),
+ convertToCLBLAS(args.a_transpose),
+ convertToCLBLAS(args.b_transpose),
+ args.m, args.n, args.k, args.alphas[batch],
+ buffers.a_mat, args.a_offsets[batch], args.a_ld,
+ buffers.b_mat, args.b_offsets[batch], args.b_ld, args.betas[batch],
+ buffers.c_mat, args.c_offsets[batch], args.c_ld,
+ 1, &queue_plain, 0, nullptr, &event);
+ clWaitForEvents(1, &event);
+ if (static_cast<StatusCode>(status) != StatusCode::kSuccess) {
+ return static_cast<StatusCode>(status);
+ }
+ }
+ return StatusCode::kSuccess;
+ }
+ #endif
+
+ // Describes how to run the CPU BLAS routine (for correctness/performance comparison)
+ #ifdef CLBLAST_REF_CBLAS
+ static StatusCode RunReference2(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
+ std::vector<T> a_mat_cpu(args.a_size, static_cast<T>(0));
+ std::vector<T> b_mat_cpu(args.b_size, static_cast<T>(0));
+ std::vector<T> c_mat_cpu(args.c_size, static_cast<T>(0));
+ buffers.a_mat.Read(queue, args.a_size, a_mat_cpu);
+ buffers.b_mat.Read(queue, args.b_size, b_mat_cpu);
+ buffers.c_mat.Read(queue, args.c_size, c_mat_cpu);
+ for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
+ cblasXgemm(convertToCBLAS(args.layout),
+ convertToCBLAS(args.a_transpose),
+ convertToCBLAS(args.b_transpose),
+ args.m, args.n, args.k, args.alphas[batch],
+ a_mat_cpu, args.a_offsets[batch], args.a_ld,
+ b_mat_cpu, args.b_offsets[batch], args.b_ld, args.betas[batch],
+ c_mat_cpu, args.c_offsets[batch], args.c_ld);
+ }
+ buffers.c_mat.Write(queue, args.c_size, c_mat_cpu);
+ return StatusCode::kSuccess;
+ }
+ #endif
+
+ // Describes how to download the results of the computation (more importantly: which buffer)
+ static std::vector<T> DownloadResult(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
+ std::vector<T> result(args.c_size, static_cast<T>(0));
+ buffers.c_mat.Read(queue, args.c_size, result);
+ return result;
+ }
+
+ // Describes how to compute the indices of the result buffer
+ static size_t ResultID1(const Arguments<T> &args) { return args.m; }
+ static size_t ResultID2(const Arguments<T> &args) { return args.n * args.batch_count; }
+ static size_t GetResultIndex(const Arguments<T> &args, const size_t id1, const size_t id2_3) {
+ const size_t id2 = id2_3 % args.n;
+ const size_t id3 = id2_3 / args.n;
+ return (args.layout == Layout::kRowMajor) ?
+ id1*args.c_ld + id2 + args.c_offsets[id3]:
+ id2*args.c_ld + id1 + args.c_offsets[id3];
+ }
+
+ // Describes how to compute performance metrics
+ static size_t GetFlops(const Arguments<T> &args) {
+ return args.batch_count * (2 * args.m * args.n * args.k);
+ }
+ static size_t GetBytes(const Arguments<T> &args) {
+ return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);
+ }
+};
+
+// =================================================================================================
+} // namespace clblast
+
+// CLBLAST_TEST_ROUTINES_XGEMMBATCHED_H_
+#endif