summaryrefslogtreecommitdiff
path: root/include/clblast_cuda.h
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-01-07 14:27:15 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2018-01-07 14:27:15 +0100
commit9fb2c61b256ccf66b6a7b6f605008125288d60cf (patch)
tree2df0c0ed7a5be8e7f1b78131467e8620a2266da7 /include/clblast_cuda.h
parent0c48c6e6c4cd953523a10bcb804fde67e4650a57 (diff)
Added API and tests for new GemmStridedBatched routine
Diffstat (limited to 'include/clblast_cuda.h')
-rw-r--r--include/clblast_cuda.h12
1 files changed, 12 insertions, 0 deletions
diff --git a/include/clblast_cuda.h b/include/clblast_cuda.h
index e1237936..b0cb9aa8 100644
--- a/include/clblast_cuda.h
+++ b/include/clblast_cuda.h
@@ -619,6 +619,18 @@ StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const T
const size_t batch_count,
const CUcontext context, const CUdevice device);
+// StridedBatched version of GEMM: SGEMMSTRIDEDBATCHED/DGEMMSTRIDEDBATCHED/CGEMMSTRIDEDBATCHED/ZGEMMSTRIDEDBATCHED/HGEMMSTRIDEDBATCHED
+template <typename T>
+StatusCode GemmStridedBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const T alpha,
+ const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
+ const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
+ const T beta,
+ CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
+ const size_t batch_count,
+ const CUcontext context, const CUdevice device);
+
// =================================================================================================
// Retrieves the required size of the temporary buffer for the GEMM kernel (optional)