summaryrefslogtreecommitdiff
path: root/src/clblast_cuda.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/clblast_cuda.cpp')
-rw-r--r--src/clblast_cuda.cpp18
1 files changed, 11 insertions, 7 deletions
diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp
index 0e3d949d..187443eb 100644
--- a/src/clblast_cuda.cpp
+++ b/src/clblast_cuda.cpp
@@ -1725,19 +1725,23 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
- const CUcontext context, const CUdevice device) {
+ const CUcontext context, const CUdevice device,
+ CUdeviceptr temp_buffer) {
try {
const auto context_cpp = Context(context);
const auto device_cpp = Device(device);
auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xgemm<T>(queue_cpp, nullptr);
+ const auto temp_buffer_provided = temp_buffer != nullptr;
+ auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(nullptr);
routine.DoGemm(layout, a_transpose, b_transpose,
m, n, k,
alpha,
Buffer<T>(a_buffer), a_offset, a_ld,
Buffer<T>(b_buffer), b_offset, b_ld,
beta,
- Buffer<T>(c_buffer), c_offset, c_ld);
+ Buffer<T>(c_buffer), c_offset, c_ld,
+ temp_buffer_cpp, temp_buffer_provided);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
@@ -1748,7 +1752,7 @@ template StatusCode PUBLIC_API Gemm<float>(const Layout, const Transpose, const
const CUdeviceptr, const size_t, const size_t,
const float,
CUdeviceptr, const size_t, const size_t,
- const CUcontext, const CUdevice);
+ const CUcontext, const CUdevice, CUdeviceptr);
template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double,
@@ -1756,7 +1760,7 @@ template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const
const CUdeviceptr, const size_t, const size_t,
const double,
CUdeviceptr, const size_t, const size_t,
- const CUcontext, const CUdevice);
+ const CUcontext, const CUdevice, CUdeviceptr);
template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const float2,
@@ -1764,7 +1768,7 @@ template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const
const CUdeviceptr, const size_t, const size_t,
const float2,
CUdeviceptr, const size_t, const size_t,
- const CUcontext, const CUdevice);
+ const CUcontext, const CUdevice, CUdeviceptr);
template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double2,
@@ -1772,7 +1776,7 @@ template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, cons
const CUdeviceptr, const size_t, const size_t,
const double2,
CUdeviceptr, const size_t, const size_t,
- const CUcontext, const CUdevice);
+ const CUcontext, const CUdevice, CUdeviceptr);
template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const half,
@@ -1780,7 +1784,7 @@ template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const T
const CUdeviceptr, const size_t, const size_t,
const half,
CUdeviceptr, const size_t, const size_t,
- const CUcontext, const CUdevice);
+ const CUcontext, const CUdevice, CUdeviceptr);
// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
template <typename T>