diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-01-07 14:27:15 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2018-01-07 14:27:15 +0100 |
commit | 9fb2c61b256ccf66b6a7b6f605008125288d60cf (patch) | |
tree | 2df0c0ed7a5be8e7f1b78131467e8620a2266da7 /src/clblast_cuda.cpp | |
parent | 0c48c6e6c4cd953523a10bcb804fde67e4650a57 (diff) |
Added API and tests for new GemmStridedBatched routine
Diffstat (limited to 'src/clblast_cuda.cpp')
-rw-r--r-- | src/clblast_cuda.cpp | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp index 348ff3f5..0aa57087 100644 --- a/src/clblast_cuda.cpp +++ b/src/clblast_cuda.cpp @@ -2436,6 +2436,79 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose, const size_t, const CUcontext, const CUdevice); +// StridedBatched version of GEMM: SGEMMSTRIDEDBATCHED/DGEMMSTRIDEDBATCHED/CGEMMSTRIDEDBATCHED/ZGEMMSTRIDEDBATCHED/HGEMMSTRIDEDBATCHED +template <typename T> +StatusCode GemmStridedBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const T alpha, + const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, + const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, + const T beta, + CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, + const size_t batch_count, + const CUcontext context, const CUdevice device) { + try { + const auto context_cpp = Context(context); + const auto device_cpp = Device(device); + auto queue_cpp = Queue(context_cpp, device_cpp); + auto routine = XgemmStridedBatched<T>(queue_cpp, nullptr); + routine.DoGemmStridedBatched(layout, a_transpose, b_transpose, + m, n, k, + alpha, + Buffer<T>(a_buffer), a_offset, a_ld, a_stride, + Buffer<T>(b_buffer), b_offset, b_ld, b_stride, + beta, + Buffer<T>(c_buffer), c_offset, c_ld, c_stride, + batch_count); + return StatusCode::kSuccess; + } catch (...) { return DispatchException(); } +} +template StatusCode PUBLIC_API GemmStridedBatched<float>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const float, + const CUdeviceptr, const size_t, const size_t, const size_t, + const CUdeviceptr, const size_t, const size_t, const size_t, + const float, + CUdeviceptr, const size_t, const size_t, const size_t, + const size_t, + const CUcontext, const CUdevice); +template StatusCode PUBLIC_API GemmStridedBatched<double>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const double, + const CUdeviceptr, const size_t, const size_t, const size_t, + const CUdeviceptr, const size_t, const size_t, const size_t, + const double, + CUdeviceptr, const size_t, const size_t, const size_t, + const size_t, + const CUcontext, const CUdevice); +template StatusCode PUBLIC_API GemmStridedBatched<float2>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const float2, + const CUdeviceptr, const size_t, const size_t, const size_t, + const CUdeviceptr, const size_t, const size_t, const size_t, + const float2, + CUdeviceptr, const size_t, const size_t, const size_t, + const size_t, + const CUcontext, const CUdevice); +template StatusCode PUBLIC_API GemmStridedBatched<double2>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const double2, + const CUdeviceptr, const size_t, const size_t, const size_t, + const CUdeviceptr, const size_t, const size_t, const size_t, + const double2, + CUdeviceptr, const size_t, const size_t, const size_t, + const size_t, + const CUcontext, const CUdevice); +template StatusCode PUBLIC_API GemmStridedBatched<half>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const half, + const CUdeviceptr, const size_t, const size_t, const size_t, + const CUdeviceptr, const size_t, const size_t, const size_t, + const half, + CUdeviceptr, const size_t, const size_t, const size_t, + const size_t, + const CUcontext, const CUdevice); + // ================================================================================================= // Retrieves the required size of the temporary buffer for the GEMM kernel (optional) |