summaryrefslogtreecommitdiff
path: root/src/clblast_cuda.cpp
diff options
context:
space:
mode:
author Cedric Nugteren <web@cedricnugteren.nl> 2018-01-07 14:27:15 +0100
committer Cedric Nugteren <web@cedricnugteren.nl> 2018-01-07 14:27:15 +0100
commit 9fb2c61b256ccf66b6a7b6f605008125288d60cf (patch)
tree 2df0c0ed7a5be8e7f1b78131467e8620a2266da7 /src/clblast_cuda.cpp
parent 0c48c6e6c4cd953523a10bcb804fde67e4650a57 (diff)
Added API and tests for new GemmStridedBatched routine
Diffstat (limited to 'src/clblast_cuda.cpp')
-rw-r--r-- src/clblast_cuda.cpp 73
1 files changed, 73 insertions, 0 deletions
diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp
index 348ff3f5..0aa57087 100644
--- a/src/clblast_cuda.cpp
+++ b/src/clblast_cuda.cpp
@@ -2436,6 +2436,79 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose,
const size_t,
const CUcontext, const CUdevice);
+// StridedBatched version of GEMM: SGEMMSTRIDEDBATCHED/DGEMMSTRIDEDBATCHED/CGEMMSTRIDEDBATCHED/ZGEMMSTRIDEDBATCHED/HGEMMSTRIDEDBATCHED
+template <typename T>  // host-side entry point: forwards every argument to XgemmStridedBatched<T>::DoGemmStridedBatched
+StatusCode GemmStridedBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const T alpha,
+ const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,  // presumably *_ld = leading dimension, *_stride = distance between consecutive batch matrices -- confirm against the routine
+ const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
+ const T beta,
+ CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,  // c_buffer is the only buffer parameter not declared const
+ const size_t batch_count,
+ const CUcontext context, const CUdevice device) {
+ try {  // everything that can throw lives in this block; the catch-all below turns it into a StatusCode
+ const auto context_cpp = Context(context);  // wrap the raw CUcontext handle
+ const auto device_cpp = Device(device);  // wrap the raw CUdevice handle
+ auto queue_cpp = Queue(context_cpp, device_cpp);  // a new Queue is constructed from the two wrappers on every call
+ auto routine = XgemmStridedBatched<T>(queue_cpp, nullptr);  // second argument nullptr: presumably "no event requested" -- confirm against the routine's constructor
+ routine.DoGemmStridedBatched(layout, a_transpose, b_transpose,
+ m, n, k,
+ alpha,
+ Buffer<T>(a_buffer), a_offset, a_ld, a_stride,  // raw CUdeviceptr handles are wrapped as typed Buffer<T> objects for the routine
+ Buffer<T>(b_buffer), b_offset, b_ld, b_stride,
+ beta,
+ Buffer<T>(c_buffer), c_offset, c_ld, c_stride,
+ batch_count);
+ return StatusCode::kSuccess;  // reached only when nothing above threw
+ } catch (...) { return DispatchException(); }  // catch-all: the returned status is whatever DispatchException() produces
+}
+template StatusCode PUBLIC_API GemmStridedBatched<float>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const float,
+ const CUdeviceptr, const size_t, const size_t, const size_t,
+ const CUdeviceptr, const size_t, const size_t, const size_t,
+ const float,
+ CUdeviceptr, const size_t, const size_t, const size_t,
+ const size_t,
+ const CUcontext, const CUdevice);  // explicit instantiation of the template for T = float, marked PUBLIC_API
+template StatusCode PUBLIC_API GemmStridedBatched<double>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const double,
+ const CUdeviceptr, const size_t, const size_t, const size_t,
+ const CUdeviceptr, const size_t, const size_t, const size_t,
+ const double,
+ CUdeviceptr, const size_t, const size_t, const size_t,
+ const size_t,
+ const CUcontext, const CUdevice);  // explicit instantiation of the template for T = double, marked PUBLIC_API
+template StatusCode PUBLIC_API GemmStridedBatched<float2>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const float2,
+ const CUdeviceptr, const size_t, const size_t, const size_t,
+ const CUdeviceptr, const size_t, const size_t, const size_t,
+ const float2,
+ CUdeviceptr, const size_t, const size_t, const size_t,
+ const size_t,
+ const CUcontext, const CUdevice);  // explicit instantiation for T = float2 (presumably the single-precision complex variant -- confirm), marked PUBLIC_API
+template StatusCode PUBLIC_API GemmStridedBatched<double2>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const double2,
+ const CUdeviceptr, const size_t, const size_t, const size_t,
+ const CUdeviceptr, const size_t, const size_t, const size_t,
+ const double2,
+ CUdeviceptr, const size_t, const size_t, const size_t,
+ const size_t,
+ const CUcontext, const CUdevice);  // explicit instantiation for T = double2 (presumably the double-precision complex variant -- confirm), marked PUBLIC_API
+template StatusCode PUBLIC_API GemmStridedBatched<half>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const half,
+ const CUdeviceptr, const size_t, const size_t, const size_t,
+ const CUdeviceptr, const size_t, const size_t, const size_t,
+ const half,
+ CUdeviceptr, const size_t, const size_t, const size_t,
+ const size_t,
+ const CUcontext, const CUdevice);  // explicit instantiation of the template for T = half, marked PUBLIC_API
+
// =================================================================================================
// Retrieves the required size of the temporary buffer for the GEMM kernel (optional)