// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
//   Cedric Nugteren
//
// This file implements all the BLAS API calls (CUDA version). In all cases, it does little more
// than create a new object of the appropriate type and call the main routine on that object.
// It forwards all status codes to the caller.
//
// NOTE(review): in this copy, `T` is used in the routine signatures below but no template
// parameter list (e.g. `template <typename T>`) is visible. The explicit instantiations carry no
// template arguments. The routine and Buffer constructions (e.g. `Xswap(...)`, `Buffer(...)`)
// show none either. These appear to have been lost from this copy of the file; as written, these
// would not form valid template definitions/instantiations. Confirm against upstream before
// building.
//
// =================================================================================================

// NOTE(review): this #include has no target in this copy -- confirm the intended header.
#include
#include "routines/routines.hpp"
#include "clblast_cuda.h"

namespace clblast {

// =================================================================================================
// BLAS level-1 (vector-vector) routines
// =================================================================================================

// Generate Givens plane rotation: SROTG/DROTG
// Not implemented in this CUDA version: all arguments are unnamed/ignored and
// StatusCode::kNotImplemented is returned.
template StatusCode Rotg(CUdeviceptr, const size_t,
                         CUdeviceptr, const size_t,
                         CUdeviceptr, const size_t,
                         CUdeviceptr, const size_t,
                         const CUcontext, const CUdevice) {
  return StatusCode::kNotImplemented;
}
template StatusCode PUBLIC_API Rotg(CUdeviceptr, const size_t,
                                    CUdeviceptr, const size_t,
                                    CUdeviceptr, const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Rotg(CUdeviceptr, const size_t,
                                    CUdeviceptr, const size_t,
                                    CUdeviceptr, const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUcontext, const CUdevice);

// Generate modified Givens plane rotation: SROTMG/DROTMG
// Not implemented in this CUDA version: all arguments are unnamed/ignored and
// StatusCode::kNotImplemented is returned.
template StatusCode Rotmg(CUdeviceptr, const size_t,
                          CUdeviceptr, const size_t,
                          CUdeviceptr, const size_t,
                          const CUdeviceptr, const size_t,
                          CUdeviceptr, const size_t,
                          const CUcontext, const CUdevice) {
  return StatusCode::kNotImplemented;
}
template StatusCode PUBLIC_API Rotmg(CUdeviceptr, const size_t,
                                     CUdeviceptr, const size_t,
                                     CUdeviceptr, const size_t,
                                     const CUdeviceptr, const size_t,
                                     CUdeviceptr, const size_t,
                                     const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Rotmg(CUdeviceptr, const size_t,
                                     CUdeviceptr, const size_t,
                                     CUdeviceptr, const size_t,
                                     const CUdeviceptr, const size_t,
                                     CUdeviceptr, const size_t,
                                     const CUcontext, const CUdevice);

// Apply Givens plane rotation: SROT/DROT
// Not implemented in this CUDA version; returns StatusCode::kNotImplemented. The two scalar
// parameters are instantiated below for float and for double.
template StatusCode Rot(const size_t,
                        CUdeviceptr, const size_t, const size_t,
                        CUdeviceptr, const size_t, const size_t,
                        const T,
                        const T,
                        const CUcontext, const CUdevice) {
  return StatusCode::kNotImplemented;
}
template StatusCode PUBLIC_API Rot(const size_t,
                                   CUdeviceptr, const size_t, const size_t,
                                   CUdeviceptr, const size_t, const size_t,
                                   const float,
                                   const float,
                                   const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Rot(const size_t,
                                   CUdeviceptr, const size_t, const size_t,
                                   CUdeviceptr, const size_t, const size_t,
                                   const double,
                                   const double,
                                   const CUcontext, const CUdevice);

// Apply modified Givens plane rotation: SROTM/DROTM
// Not implemented in this CUDA version; returns StatusCode::kNotImplemented.
template StatusCode Rotm(const size_t,
                         CUdeviceptr, const size_t, const size_t,
                         CUdeviceptr, const size_t, const size_t,
                         CUdeviceptr, const size_t,
                         const CUcontext, const CUdevice) {
  return StatusCode::kNotImplemented;
}
template StatusCode PUBLIC_API Rotm(const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Rotm(const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUcontext, const CUdevice);

// Swap two vectors: SSWAP/DSWAP/CSWAP/ZSWAP/HSWAP
//
// Wrapper pattern shared by every implemented routine in this file:
//  1. Wrap the raw CUcontext/CUdevice handles in the library's Context/Device types.
//  2. Construct a new Queue from them on every call.
//  3. Construct the matching routine object (here Xswap) and call its Do* method, wrapping each
//     raw CUdeviceptr argument as a Buffer.
// Returns StatusCode::kSuccess when no exception escapes. Any exception is caught by `catch (...)`
// and turned into the returned status code by DispatchException().
// NOTE(review): the `nullptr` passed to the routine constructor is presumably an optional event
// argument -- confirm against the routine class.
template StatusCode Swap(const size_t n,
                         CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                         CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
                         const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xswap(queue_cpp, nullptr);
    routine.DoSwap(n,
                   Buffer(x_buffer), x_offset, x_inc,
                   Buffer(y_buffer), y_offset, y_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
// Exported explicit instantiations. There is presumably one per precision named in the routine
// comment; the template arguments are not visible in this copy.
template StatusCode PUBLIC_API Swap(const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Swap(const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Swap(const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Swap(const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Swap(const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);

// Vector scaling: SSCAL/DSCAL/CSCAL/ZSCAL/HSCAL
// Forwards to Xscal::DoScal. alpha is passed by value, and x is the only buffer argument.
// Instantiated for float, double, float2, double2 and half.
template StatusCode Scal(const size_t n,
                         const T alpha,
                         CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                         const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xscal(queue_cpp, nullptr);
    routine.DoScal(n,
                   alpha,
                   Buffer(x_buffer), x_offset, x_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Scal(const size_t,
                                    const float,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Scal(const size_t,
                                    const double,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Scal(const size_t,
                                    const float2,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Scal(const size_t,
                                    const double2,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Scal(const size_t,
                                    const half,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);

// Vector copy: SCOPY/DCOPY/CCOPY/ZCOPY/HCOPY
// Forwards to Xcopy::DoCopy. x is taken as a const CUdeviceptr and y as a mutable one.
template StatusCode Copy(const size_t n,
                         const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                         CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
                         const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xcopy(queue_cpp, nullptr);
    routine.DoCopy(n,
                   Buffer(x_buffer), x_offset, x_inc,
                   Buffer(y_buffer), y_offset, y_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Copy(const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Copy(const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Copy(const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Copy(const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Copy(const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);

// Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY/HAXPY
// Forwards to Xaxpy::DoAxpy. x is read-only (const CUdeviceptr), and y is the mutable buffer.
// Instantiated for float, double, float2, double2 and half.
template StatusCode Axpy(const size_t n,
                         const T alpha,
                         const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                         CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
                         const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xaxpy(queue_cpp, nullptr);
    routine.DoAxpy(n,
                   alpha,
                   Buffer(x_buffer), x_offset, x_inc,
                   Buffer(y_buffer), y_offset, y_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Axpy(const size_t,
                                    const float,
                                    const CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Axpy(const size_t,
                                    const double,
                                    const CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Axpy(const size_t,
                                    const float2,
                                    const CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Axpy(const size_t,
                                    const double2,
                                    const CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Axpy(const size_t,
                                    const half,
                                    const CUdeviceptr, const size_t, const size_t,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);

// Dot product of two vectors: SDOT/DDOT/HDOT
// Forwards to Xdot::DoDot. x and y are read-only (const CUdeviceptr). dot_buffer at dot_offset is
// the only mutable buffer, presumably the destination of the result.
template StatusCode Dot(const size_t n,
                        CUdeviceptr dot_buffer, const size_t dot_offset,
                        const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                        const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
                        const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xdot(queue_cpp, nullptr);
    routine.DoDot(n,
                  Buffer(dot_buffer), dot_offset,
                  Buffer(x_buffer), x_offset, x_inc,
                  Buffer(y_buffer), y_offset, y_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Dot(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Dot(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Dot(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);

// Dot product of two complex vectors: CDOTU/ZDOTU
// Same argument layout as Dot; forwards to Xdotu::DoDotu.
template StatusCode Dotu(const size_t n,
                         CUdeviceptr dot_buffer, const size_t dot_offset,
                         const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                         const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
                         const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xdotu(queue_cpp, nullptr);
    routine.DoDotu(n,
                   Buffer(dot_buffer), dot_offset,
                   Buffer(x_buffer), x_offset, x_inc,
                   Buffer(y_buffer), y_offset, y_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Dotu(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Dotu(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);

// Dot product of two complex vectors, one conjugated: CDOTC/ZDOTC
// Same argument layout as Dot; forwards to Xdotc::DoDotc.
template StatusCode Dotc(const size_t n,
                         CUdeviceptr dot_buffer, const size_t dot_offset,
                         const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                         const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
                         const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xdotc(queue_cpp, nullptr);
    routine.DoDotc(n,
                   Buffer(dot_buffer), dot_offset,
                   Buffer(x_buffer), x_offset, x_inc,
                   Buffer(y_buffer), y_offset, y_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Dotc(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Dotc(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);

// Euclidean norm of a vector: SNRM2/DNRM2/ScNRM2/DzNRM2/HNRM2
// Forwards to Xnrm2::DoNrm2. x is read-only; nrm2_buffer at nrm2_offset is the mutable buffer,
// presumably receiving the result.
template StatusCode Nrm2(const size_t n,
                         CUdeviceptr nrm2_buffer, const size_t nrm2_offset,
                         const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                         const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xnrm2(queue_cpp, nullptr);
    routine.DoNrm2(n,
                   Buffer(nrm2_buffer), nrm2_offset,
                   Buffer(x_buffer), x_offset, x_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Nrm2(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Nrm2(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Nrm2(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Nrm2(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Nrm2(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);

// Absolute sum of values in a vector: SASUM/DASUM/ScASUM/DzASUM/HASUM
// Same argument layout as Nrm2; forwards to Xasum::DoAsum.
template StatusCode Asum(const size_t n,
                         CUdeviceptr asum_buffer, const size_t asum_offset,
                         const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                         const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xasum(queue_cpp, nullptr);
    routine.DoAsum(n,
                   Buffer(asum_buffer), asum_offset,
                   Buffer(x_buffer), x_offset, x_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Asum(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Asum(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Asum(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Asum(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Asum(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);

// Sum of values in a vector (non-BLAS function): SSUM/DSUM/ScSUM/DzSUM/HSUM
// Same argument layout as Nrm2; forwards to Xsum::DoSum.
template StatusCode Sum(const size_t n,
                        CUdeviceptr sum_buffer, const size_t sum_offset,
                        const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                        const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xsum(queue_cpp, nullptr);
    routine.DoSum(n,
                  Buffer(sum_buffer), sum_offset,
                  Buffer(x_buffer), x_offset, x_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Sum(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Sum(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Sum(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Sum(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Sum(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);

// Index of absolute maximum value in a vector: iSAMAX/iDAMAX/iCAMAX/iZAMAX/iHAMAX
// Forwards to Xamax::DoAmax. imax_buffer at imax_offset is the mutable (index) buffer.
// NOTE(review): the element type used to wrap imax_buffer is not visible in this copy -- confirm.
template StatusCode Amax(const size_t n,
                         CUdeviceptr imax_buffer, const size_t imax_offset,
                         const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                         const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xamax(queue_cpp, nullptr);
    routine.DoAmax(n,
                   Buffer(imax_buffer), imax_offset,
                   Buffer(x_buffer), x_offset, x_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Amax(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Amax(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Amax(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Amax(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Amax(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);

// Index of absolute minimum value in a vector (non-BLAS function): iSAMIN/iDAMIN/iCAMIN/iZAMIN/iHAMIN
// Same argument layout as Amax; forwards to Xamin::DoAmin.
template StatusCode Amin(const size_t n,
                         CUdeviceptr imin_buffer, const size_t imin_offset,
                         const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                         const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xamin(queue_cpp, nullptr);
    routine.DoAmin(n,
                   Buffer(imin_buffer), imin_offset,
                   Buffer(x_buffer), x_offset, x_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Amin(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Amin(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Amin(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Amin(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Amin(const size_t,
                                    CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);

// Index of maximum value in a vector (non-BLAS function): iSMAX/iDMAX/iCMAX/iZMAX/iHMAX
// Same argument layout as Amax; forwards to Xmax::DoMax.
template StatusCode Max(const size_t n,
                        CUdeviceptr imax_buffer, const size_t imax_offset,
                        const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                        const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xmax(queue_cpp, nullptr);
    routine.DoMax(n,
                  Buffer(imax_buffer), imax_offset,
                  Buffer(x_buffer), x_offset, x_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Max(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Max(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Max(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Max(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Max(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);

// Index of minimum value in a vector (non-BLAS function): iSMIN/iDMIN/iCMIN/iZMIN/iHMIN
// Same argument layout as Amax; forwards to Xmin::DoMin.
template StatusCode Min(const size_t n,
                        CUdeviceptr imin_buffer, const size_t imin_offset,
                        const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                        const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xmin(queue_cpp, nullptr);
    routine.DoMin(n,
                  Buffer(imin_buffer), imin_offset,
                  Buffer(x_buffer), x_offset, x_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Min(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Min(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Min(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Min(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Min(const size_t,
                                   CUdeviceptr, const size_t,
                                   const CUdeviceptr, const size_t, const size_t,
                                   const CUcontext, const CUdevice);

// =================================================================================================
// BLAS level-2 (matrix-vector) routines
// =================================================================================================

// General matrix-vector multiplication: SGEMV/DGEMV/CGEMV/ZGEMV/HGEMV
// Forwards to Xgemv::DoGemv. A and x are read-only (const CUdeviceptr); y is the mutable buffer.
// Instantiated for float, double, float2, double2 and half.
template StatusCode Gemv(const Layout layout, const Transpose a_transpose,
                         const size_t m, const size_t n,
                         const T alpha,
                         const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
                         const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                         const T beta,
                         CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
                         const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xgemv(queue_cpp, nullptr);
    routine.DoGemv(layout, a_transpose,
                   m, n,
                   alpha,
                   Buffer(a_buffer), a_offset, a_ld,
                   Buffer(x_buffer), x_offset, x_inc,
                   beta,
                   Buffer(y_buffer), y_offset, y_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Gemv(const Layout, const Transpose,
                                    const size_t, const size_t,
                                    const float,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const float,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gemv(const Layout, const Transpose,
                                    const size_t, const size_t,
                                    const double,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const double,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gemv(const Layout, const Transpose,
                                    const size_t, const size_t,
                                    const float2,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const float2,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gemv(const Layout, const Transpose,
                                    const size_t, const size_t,
                                    const double2,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const double2,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gemv(const Layout, const Transpose,
                                    const size_t, const size_t,
                                    const half,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const half,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);

// General banded matrix-vector multiplication: SGBMV/DGBMV/CGBMV/ZGBMV/HGBMV
// Like Gemv but with the extra kl/ku arguments; forwards to Xgbmv::DoGbmv.
template StatusCode Gbmv(const Layout layout, const Transpose a_transpose,
                         const size_t m, const size_t n, const size_t kl, const size_t ku,
                         const T alpha,
                         const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
                         const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                         const T beta,
                         CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
                         const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xgbmv(queue_cpp, nullptr);
    routine.DoGbmv(layout, a_transpose,
                   m, n, kl, ku,
                   alpha,
                   Buffer(a_buffer), a_offset, a_ld,
                   Buffer(x_buffer), x_offset, x_inc,
                   beta,
                   Buffer(y_buffer), y_offset, y_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Gbmv(const Layout, const Transpose,
                                    const size_t, const size_t, const size_t, const size_t,
                                    const float,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const float,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gbmv(const Layout, const Transpose,
                                    const size_t, const size_t, const size_t, const size_t,
                                    const double,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const double,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gbmv(const Layout, const Transpose,
                                    const size_t, const size_t, const size_t, const size_t,
                                    const float2,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const float2,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gbmv(const Layout, const Transpose,
                                    const size_t, const size_t, const size_t, const size_t,
                                    const double2,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const double2,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gbmv(const Layout, const Transpose,
                                    const size_t, const size_t, const size_t, const size_t,
                                    const half,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const half,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);

// Hermitian matrix-vector multiplication: CHEMV/ZHEMV
// Forwards to Xhemv::DoHemv; takes a Triangle instead of a Transpose and a single dimension n.
// Instantiated for float2 and double2 only.
template StatusCode Hemv(const Layout layout, const Triangle triangle,
                         const size_t n,
                         const T alpha,
                         const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
                         const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                         const T beta,
                         CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
                         const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xhemv(queue_cpp, nullptr);
    routine.DoHemv(layout, triangle,
                   n,
                   alpha,
                   Buffer(a_buffer), a_offset, a_ld,
                   Buffer(x_buffer), x_offset, x_inc,
                   beta,
                   Buffer(y_buffer), y_offset, y_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Hemv(const Layout, const Triangle,
                                    const size_t,
                                    const float2,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const float2,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Hemv(const Layout, const Triangle,
                                    const size_t,
                                    const double2,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const double2,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);

// Hermitian banded matrix-vector multiplication: CHBMV/ZHBMV
// Like Hemv plus the extra k argument; forwards to Xhbmv::DoHbmv. Instantiated for float2/double2.
template StatusCode Hbmv(const Layout layout, const Triangle triangle,
                         const size_t n, const size_t k,
                         const T alpha,
                         const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
                         const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                         const T beta,
                         CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
                         const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xhbmv(queue_cpp, nullptr);
    routine.DoHbmv(layout, triangle,
                   n, k,
                   alpha,
                   Buffer(a_buffer), a_offset, a_ld,
                   Buffer(x_buffer), x_offset, x_inc,
                   beta,
                   Buffer(y_buffer), y_offset, y_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Hbmv(const Layout, const Triangle,
                                    const size_t, const size_t,
                                    const float2,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const float2,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Hbmv(const Layout, const Triangle,
                                    const size_t, const size_t,
                                    const double2,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const double2,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);

// Hermitian packed matrix-vector multiplication: CHPMV/ZHPMV
// Like Hemv, but the matrix is passed as ap_buffer/ap_offset with no leading-dimension argument.
// Forwards to Xhpmv::DoHpmv. Instantiated for float2/double2.
template StatusCode Hpmv(const Layout layout, const Triangle triangle,
                         const size_t n,
                         const T alpha,
                         const CUdeviceptr ap_buffer, const size_t ap_offset,
                         const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                         const T beta,
                         CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
                         const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xhpmv(queue_cpp, nullptr);
    routine.DoHpmv(layout, triangle,
                   n,
                   alpha,
                   Buffer(ap_buffer), ap_offset,
                   Buffer(x_buffer), x_offset, x_inc,
                   beta,
                   Buffer(y_buffer), y_offset, y_inc);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API Hpmv(const Layout, const Triangle,
                                    const size_t,
                                    const float2,
                                    const CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const float2,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Hpmv(const Layout, const Triangle,
                                    const size_t,
                                    const double2,
                                    const CUdeviceptr, const size_t,
                                    const CUdeviceptr, const size_t, const size_t,
                                    const double2,
                                    CUdeviceptr, const size_t, const size_t,
                                    const CUcontext, const CUdevice);

// Symmetric matrix-vector multiplication: SSYMV/DSYMV/HSYMV
// Same argument layout as Hemv; forwards to Xsymv::DoSymv.
template StatusCode Symv(const Layout layout, const Triangle triangle,
                         const size_t n,
                         const T alpha,
                         const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
                         const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
                         const T beta,
                         CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
                         const CUcontext context, const CUdevice device) {
  try {
    const auto context_cpp = Context(context);
    const auto device_cpp = Device(device);
    auto queue_cpp = Queue(context_cpp, device_cpp);
    auto routine = Xsymv(queue_cpp, nullptr);
    routine.DoSymv(layout, triangle,
                   n,
                   alpha,
                   Buffer(a_buffer), a_offset, a_ld,
                   Buffer(x_buffer), x_offset, x_inc,
                   beta,
                   Buffer(y_buffer), y_offset, y_inc);
    return StatusCode::kSuccess;
  } catch (...)
{ return DispatchException(); } } template StatusCode PUBLIC_API Symv(const Layout, const Triangle, const size_t, const float, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const float, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Symv(const Layout, const Triangle, const size_t, const double, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const double, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Symv(const Layout, const Triangle, const size_t, const half, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const half, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Symmetric banded matrix-vector multiplication: SSBMV/DSBMV/HSBMV template StatusCode Sbmv(const Layout layout, const Triangle triangle, const size_t n, const size_t k, const T alpha, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const T beta, CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xsbmv(queue_cpp, nullptr); routine.DoSbmv(layout, triangle, n, k, alpha, Buffer(a_buffer), a_offset, a_ld, Buffer(x_buffer), x_offset, x_inc, beta, Buffer(y_buffer), y_offset, y_inc); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Sbmv(const Layout, const Triangle, const size_t, const size_t, const float, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const float, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Sbmv(const Layout, const Triangle, const size_t, const size_t, const double, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const double, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Sbmv(const Layout, const Triangle, const size_t, const size_t, const half, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const half, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Symmetric packed matrix-vector multiplication: SSPMV/DSPMV/HSPMV template StatusCode Spmv(const Layout layout, const Triangle triangle, const size_t n, const T alpha, const CUdeviceptr ap_buffer, const size_t ap_offset, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const T beta, CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xspmv(queue_cpp, nullptr); routine.DoSpmv(layout, triangle, n, alpha, Buffer(ap_buffer), ap_offset, Buffer(x_buffer), x_offset, x_inc, beta, Buffer(y_buffer), y_offset, y_inc); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Spmv(const Layout, const Triangle, const size_t, const float, const CUdeviceptr, const size_t, const CUdeviceptr, const size_t, const size_t, const float, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Spmv(const Layout, const Triangle, const size_t, const double, const CUdeviceptr, const size_t, const CUdeviceptr, const size_t, const size_t, const double, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Spmv(const Layout, const Triangle, const size_t, const half, const CUdeviceptr, const size_t, const CUdeviceptr, const size_t, const size_t, const half, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Triangular matrix-vector multiplication: STRMV/DTRMV/CTRMV/ZTRMV/HTRMV template StatusCode Trmv(const Layout layout, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, const size_t n, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xtrmv(queue_cpp, nullptr); routine.DoTrmv(layout, triangle, a_transpose, diagonal, n, Buffer(a_buffer), a_offset, a_ld, Buffer(x_buffer), x_offset, x_inc); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Trmv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Trmv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Trmv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Trmv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Trmv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Triangular banded matrix-vector multiplication: STBMV/DTBMV/CTBMV/ZTBMV/HTBMV template StatusCode Tbmv(const Layout layout, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, const size_t n, const size_t k, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xtbmv(queue_cpp, nullptr); routine.DoTbmv(layout, triangle, a_transpose, diagonal, n, k, Buffer(a_buffer), a_offset, a_ld, Buffer(x_buffer), x_offset, x_inc); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Tbmv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Tbmv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Tbmv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Tbmv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Tbmv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Triangular packed matrix-vector multiplication: STPMV/DTPMV/CTPMV/ZTPMV/HTPMV template StatusCode Tpmv(const Layout layout, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, const size_t n, const CUdeviceptr ap_buffer, const size_t ap_offset, CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xtpmv(queue_cpp, nullptr); routine.DoTpmv(layout, triangle, a_transpose, diagonal, n, Buffer(ap_buffer), ap_offset, Buffer(x_buffer), x_offset, x_inc); return 
StatusCode::kSuccess; } catch (...) { return DispatchException(); } } template StatusCode PUBLIC_API Tpmv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Tpmv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Tpmv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Tpmv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Tpmv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Solves a triangular system of equations: STRSV/DTRSV/CTRSV/ZTRSV template StatusCode Trsv(const Layout layout, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, const size_t n, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xtrsv(queue_cpp, nullptr); routine.DoTrsv(layout, triangle, a_transpose, diagonal, n, Buffer(a_buffer), a_offset, a_ld, Buffer(x_buffer), x_offset, x_inc); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Trsv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Trsv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Trsv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Trsv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Solves a banded triangular system of equations: STBSV/DTBSV/CTBSV/ZTBSV template StatusCode Tbsv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice) { return StatusCode::kNotImplemented; } template StatusCode PUBLIC_API Tbsv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Tbsv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Tbsv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, 
CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Tbsv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Solves a packed triangular system of equations: STPSV/DTPSV/CTPSV/ZTPSV template StatusCode Tpsv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice) { return StatusCode::kNotImplemented; } template StatusCode PUBLIC_API Tpsv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Tpsv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Tpsv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Tpsv(const Layout, const Triangle, const Transpose, const Diagonal, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // General rank-1 matrix update: SGER/DGER/HGER template StatusCode Ger(const Layout layout, const size_t m, const size_t n, const T alpha, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp 
= Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xger(queue_cpp, nullptr); routine.DoGer(layout, m, n, alpha, Buffer(x_buffer), x_offset, x_inc, Buffer(y_buffer), y_offset, y_inc, Buffer(a_buffer), a_offset, a_ld); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } template StatusCode PUBLIC_API Ger(const Layout, const size_t, const size_t, const float, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Ger(const Layout, const size_t, const size_t, const double, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Ger(const Layout, const size_t, const size_t, const half, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // General rank-1 complex matrix update: CGERU/ZGERU template StatusCode Geru(const Layout layout, const size_t m, const size_t n, const T alpha, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xgeru(queue_cpp, nullptr); routine.DoGeru(layout, m, n, alpha, Buffer(x_buffer), x_offset, x_inc, Buffer(y_buffer), y_offset, y_inc, Buffer(a_buffer), a_offset, a_ld); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Geru(const Layout, const size_t, const size_t, const float2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Geru(const Layout, const size_t, const size_t, const double2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // General rank-1 complex conjugated matrix update: CGERC/ZGERC template StatusCode Gerc(const Layout layout, const size_t m, const size_t n, const T alpha, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xgerc(queue_cpp, nullptr); routine.DoGerc(layout, m, n, alpha, Buffer(x_buffer), x_offset, x_inc, Buffer(y_buffer), y_offset, y_inc, Buffer(a_buffer), a_offset, a_ld); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Gerc(const Layout, const size_t, const size_t, const float2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Gerc(const Layout, const size_t, const size_t, const double2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Hermitian rank-1 matrix update: CHER/ZHER template StatusCode Her(const Layout layout, const Triangle triangle, const size_t n, const T alpha, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xher,T>(queue_cpp, nullptr); routine.DoHer(layout, triangle, n, alpha, Buffer>(x_buffer), x_offset, x_inc, Buffer>(a_buffer), a_offset, a_ld); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Her(const Layout, const Triangle, const size_t, const float, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Her(const Layout, const Triangle, const size_t, const double, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Hermitian packed rank-1 matrix update: CHPR/ZHPR template StatusCode Hpr(const Layout layout, const Triangle triangle, const size_t n, const T alpha, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, CUdeviceptr ap_buffer, const size_t ap_offset, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xhpr,T>(queue_cpp, nullptr); routine.DoHpr(layout, triangle, n, alpha, Buffer>(x_buffer), x_offset, x_inc, Buffer>(ap_buffer), ap_offset); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Hpr(const Layout, const Triangle, const size_t, const float, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Hpr(const Layout, const Triangle, const size_t, const double, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); // Hermitian rank-2 matrix update: CHER2/ZHER2 template StatusCode Her2(const Layout layout, const Triangle triangle, const size_t n, const T alpha, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xher2(queue_cpp, nullptr); routine.DoHer2(layout, triangle, n, alpha, Buffer(x_buffer), x_offset, x_inc, Buffer(y_buffer), y_offset, y_inc, Buffer(a_buffer), a_offset, a_ld); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Her2(const Layout, const Triangle, const size_t, const float2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Her2(const Layout, const Triangle, const size_t, const double2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Hermitian packed rank-2 matrix update: CHPR2/ZHPR2 template StatusCode Hpr2(const Layout layout, const Triangle triangle, const size_t n, const T alpha, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr ap_buffer, const size_t ap_offset, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xhpr2(queue_cpp, nullptr); routine.DoHpr2(layout, triangle, n, alpha, Buffer(x_buffer), x_offset, x_inc, Buffer(y_buffer), y_offset, y_inc, Buffer(ap_buffer), ap_offset); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Hpr2(const Layout, const Triangle, const size_t, const float2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Hpr2(const Layout, const Triangle, const size_t, const double2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); // Symmetric rank-1 matrix update: SSYR/DSYR/HSYR template StatusCode Syr(const Layout layout, const Triangle triangle, const size_t n, const T alpha, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xsyr(queue_cpp, nullptr); routine.DoSyr(layout, triangle, n, alpha, Buffer(x_buffer), x_offset, x_inc, Buffer(a_buffer), a_offset, a_ld); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Syr(const Layout, const Triangle, const size_t, const float, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Syr(const Layout, const Triangle, const size_t, const double, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Syr(const Layout, const Triangle, const size_t, const half, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Symmetric packed rank-1 matrix update: SSPR/DSPR/HSPR template StatusCode Spr(const Layout layout, const Triangle triangle, const size_t n, const T alpha, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, CUdeviceptr ap_buffer, const size_t ap_offset, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xspr(queue_cpp, nullptr); routine.DoSpr(layout, triangle, n, alpha, Buffer(x_buffer), x_offset, x_inc, Buffer(ap_buffer), ap_offset); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Spr(const Layout, const Triangle, const size_t, const float, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Spr(const Layout, const Triangle, const size_t, const double, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Spr(const Layout, const Triangle, const size_t, const half, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); // Symmetric rank-2 matrix update: SSYR2/DSYR2/HSYR2 template StatusCode Syr2(const Layout layout, const Triangle triangle, const size_t n, const T alpha, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xsyr2(queue_cpp, nullptr); routine.DoSyr2(layout, triangle, n, alpha, Buffer(x_buffer), x_offset, x_inc, Buffer(y_buffer), y_offset, y_inc, Buffer(a_buffer), a_offset, a_ld); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Syr2(const Layout, const Triangle, const size_t, const float, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Syr2(const Layout, const Triangle, const size_t, const double, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Syr2(const Layout, const Triangle, const size_t, const half, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Symmetric packed rank-2 matrix update: SSPR2/DSPR2/HSPR2 template StatusCode Spr2(const Layout layout, const Triangle triangle, const size_t n, const T alpha, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr ap_buffer, const size_t ap_offset, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xspr2(queue_cpp, nullptr); routine.DoSpr2(layout, triangle, n, alpha, Buffer(x_buffer), x_offset, x_inc, Buffer(y_buffer), y_offset, y_inc, Buffer(ap_buffer), ap_offset); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Spr2(const Layout, const Triangle, const size_t, const float, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Spr2(const Layout, const Triangle, const size_t, const double, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Spr2(const Layout, const Triangle, const size_t, const half, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); // ================================================================================================= // BLAS level-3 (matrix-matrix) routines // ================================================================================================= // General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM template StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, const size_t m, const size_t n, const size_t k, const T alpha, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const T beta, CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, const CUcontext context, const CUdevice device, CUdeviceptr temp_buffer) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xgemm(queue_cpp, nullptr); const auto temp_buffer_provided = temp_buffer != 0; auto temp_buffer_cpp = temp_buffer_provided ? 
Buffer(temp_buffer) : Buffer(0); routine.DoGemm(layout, a_transpose, b_transpose, m, n, k, alpha, Buffer(a_buffer), a_offset, a_ld, Buffer(b_buffer), b_offset, b_ld, beta, Buffer(c_buffer), c_offset, c_ld, temp_buffer_cpp, temp_buffer_provided); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } template StatusCode PUBLIC_API Gemm(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const float, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const float, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice, CUdeviceptr); template StatusCode PUBLIC_API Gemm(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const double, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const double, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice, CUdeviceptr); template StatusCode PUBLIC_API Gemm(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const float2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const float2, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice, CUdeviceptr); template StatusCode PUBLIC_API Gemm(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const double2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const double2, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice, CUdeviceptr); template StatusCode PUBLIC_API Gemm(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const half, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const half, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice, CUdeviceptr); 
// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM template StatusCode Symm(const Layout layout, const Side side, const Triangle triangle, const size_t m, const size_t n, const T alpha, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const T beta, CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xsymm(queue_cpp, nullptr); routine.DoSymm(layout, side, triangle, m, n, alpha, Buffer(a_buffer), a_offset, a_ld, Buffer(b_buffer), b_offset, b_ld, beta, Buffer(c_buffer), c_offset, c_ld); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } template StatusCode PUBLIC_API Symm(const Layout, const Side, const Triangle, const size_t, const size_t, const float, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const float, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Symm(const Layout, const Side, const Triangle, const size_t, const size_t, const double, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const double, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Symm(const Layout, const Side, const Triangle, const size_t, const size_t, const float2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const float2, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Symm(const Layout, const Side, const Triangle, const size_t, const size_t, const double2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const 
size_t, const size_t, const double2, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Symm(const Layout, const Side, const Triangle, const size_t, const size_t, const half, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const half, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Hermitian matrix-matrix multiplication: CHEMM/ZHEMM template StatusCode Hemm(const Layout layout, const Side side, const Triangle triangle, const size_t m, const size_t n, const T alpha, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const T beta, CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xhemm(queue_cpp, nullptr); routine.DoHemm(layout, side, triangle, m, n, alpha, Buffer(a_buffer), a_offset, a_ld, Buffer(b_buffer), b_offset, b_ld, beta, Buffer(c_buffer), c_offset, c_ld); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Hemm(const Layout, const Side, const Triangle, const size_t, const size_t, const float2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const float2, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Hemm(const Layout, const Side, const Triangle, const size_t, const size_t, const double2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const double2, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK/HSYRK template StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose a_transpose, const size_t n, const size_t k, const T alpha, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const T beta, CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xsyrk(queue_cpp, nullptr); routine.DoSyrk(layout, triangle, a_transpose, n, k, alpha, Buffer(a_buffer), a_offset, a_ld, beta, Buffer(c_buffer), c_offset, c_ld); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Syrk(const Layout, const Triangle, const Transpose, const size_t, const size_t, const float, const CUdeviceptr, const size_t, const size_t, const float, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Syrk(const Layout, const Triangle, const Transpose, const size_t, const size_t, const double, const CUdeviceptr, const size_t, const size_t, const double, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Syrk(const Layout, const Triangle, const Transpose, const size_t, const size_t, const float2, const CUdeviceptr, const size_t, const size_t, const float2, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Syrk(const Layout, const Triangle, const Transpose, const size_t, const size_t, const double2, const CUdeviceptr, const size_t, const size_t, const double2, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Syrk(const Layout, const Triangle, const Transpose, const size_t, const size_t, const half, const CUdeviceptr, const size_t, const size_t, const half, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Rank-K update of a hermitian matrix: CHERK/ZHERK template StatusCode Herk(const Layout layout, const Triangle triangle, const Transpose a_transpose, const size_t n, const size_t k, const T alpha, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const T beta, CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xherk,T>(queue_cpp, nullptr); routine.DoHerk(layout, triangle, a_transpose, n, k, alpha, Buffer>(a_buffer), 
a_offset, a_ld, beta, Buffer>(c_buffer), c_offset, c_ld); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } template StatusCode PUBLIC_API Herk(const Layout, const Triangle, const Transpose, const size_t, const size_t, const float, const CUdeviceptr, const size_t, const size_t, const float, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Herk(const Layout, const Triangle, const Transpose, const size_t, const size_t, const double, const CUdeviceptr, const size_t, const size_t, const double, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K/HSYR2K template StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, const size_t n, const size_t k, const T alpha, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const T beta, CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xsyr2k(queue_cpp, nullptr); routine.DoSyr2k(layout, triangle, ab_transpose, n, k, alpha, Buffer(a_buffer), a_offset, a_ld, Buffer(b_buffer), b_offset, b_ld, beta, Buffer(c_buffer), c_offset, c_ld); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Syr2k(const Layout, const Triangle, const Transpose, const size_t, const size_t, const float, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const float, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Syr2k(const Layout, const Triangle, const Transpose, const size_t, const size_t, const double, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const double, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Syr2k(const Layout, const Triangle, const Transpose, const size_t, const size_t, const float2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const float2, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Syr2k(const Layout, const Triangle, const Transpose, const size_t, const size_t, const double2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const double2, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Syr2k(const Layout, const Triangle, const Transpose, const size_t, const size_t, const half, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const half, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Rank-2K update of a hermitian matrix: CHER2K/ZHER2K template StatusCode Her2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, const size_t n, const size_t k, const T alpha, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const U beta, CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, 
const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xher2k(queue_cpp, nullptr); routine.DoHer2k(layout, triangle, ab_transpose, n, k, alpha, Buffer(a_buffer), a_offset, a_ld, Buffer(b_buffer), b_offset, b_ld, beta, Buffer(c_buffer), c_offset, c_ld); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } template StatusCode PUBLIC_API Her2k(const Layout, const Triangle, const Transpose, const size_t, const size_t, const float2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const float, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Her2k(const Layout, const Triangle, const Transpose, const size_t, const size_t, const double2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const double, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM/HTRMM template StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, const size_t m, const size_t n, const T alpha, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xtrmm(queue_cpp, nullptr); routine.DoTrmm(layout, side, triangle, a_transpose, diagonal, m, n, alpha, Buffer(a_buffer), a_offset, a_ld, Buffer(b_buffer), b_offset, b_ld); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Trmm(const Layout, const Side, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const float, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Trmm(const Layout, const Side, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const double, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Trmm(const Layout, const Side, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const float2, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Trmm(const Layout, const Side, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const double2, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Trmm(const Layout, const Side, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const half, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM template StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, const size_t m, const size_t n, const T alpha, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine 
= Xtrsm(queue_cpp, nullptr); routine.DoTrsm(layout, side, triangle, a_transpose, diagonal, m, n, alpha, Buffer(a_buffer), a_offset, a_ld, Buffer(b_buffer), b_offset, b_ld); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } template StatusCode PUBLIC_API Trsm(const Layout, const Side, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const float, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Trsm(const Layout, const Side, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const double, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Trsm(const Layout, const Side, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const float2, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Trsm(const Layout, const Side, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, const double2, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // ================================================================================================= // Extra non-BLAS routines (level-X) // ================================================================================================= // Element-wise vector product (Hadamard): SHAD/DHAD/CHAD/ZHAD/HHAD template StatusCode Had(const size_t n, const T alpha, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, const T beta, CUdeviceptr z_buffer, const size_t z_offset, const size_t z_inc, const CUcontext context, const CUdevice device) { try { const auto 
context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xhad(queue_cpp, nullptr); routine.DoHad(n, alpha, Buffer(x_buffer), x_offset, x_inc, Buffer(y_buffer), y_offset, y_inc, beta, Buffer(z_buffer), z_offset, z_inc); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } template StatusCode PUBLIC_API Had(const size_t, const float, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const float, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Had(const size_t, const double, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const double, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Had(const size_t, const float2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const float2, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Had(const size_t, const double2, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const double2, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Had(const size_t, const half, const CUdeviceptr, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const half, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Scaling and out-place transpose/copy (non-BLAS function): SOMATCOPY/DOMATCOPY/COMATCOPY/ZOMATCOPY/HOMATCOPY template StatusCode Omatcopy(const Layout layout, const Transpose a_transpose, const size_t m, const size_t n, const T alpha, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const CUcontext context, 
const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xomatcopy(queue_cpp, nullptr); routine.DoOmatcopy(layout, a_transpose, m, n, alpha, Buffer(a_buffer), a_offset, a_ld, Buffer(b_buffer), b_offset, b_ld); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } template StatusCode PUBLIC_API Omatcopy(const Layout, const Transpose, const size_t, const size_t, const float, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Omatcopy(const Layout, const Transpose, const size_t, const size_t, const double, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Omatcopy(const Layout, const Transpose, const size_t, const size_t, const float2, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Omatcopy(const Layout, const Transpose, const size_t, const size_t, const double2, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Omatcopy(const Layout, const Transpose, const size_t, const size_t, const half, const CUdeviceptr, const size_t, const size_t, CUdeviceptr, const size_t, const size_t, const CUcontext, const CUdevice); // Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL template StatusCode Im2col(const KernelMode kernel_mode, const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const CUdeviceptr im_buffer, const 
size_t im_offset, CUdeviceptr col_buffer, const size_t col_offset, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xim2col(queue_cpp, nullptr); routine.DoIm2col(kernel_mode, channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, Buffer(im_buffer), im_offset, Buffer(col_buffer), col_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } template StatusCode PUBLIC_API Im2col(const KernelMode, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Im2col(const KernelMode, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Im2col(const KernelMode, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Im2col(const KernelMode, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Im2col(const KernelMode, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, 
const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); // Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM template StatusCode Col2im(const KernelMode kernel_mode, const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const CUdeviceptr col_buffer, const size_t col_offset, CUdeviceptr im_buffer, const size_t im_offset, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xcol2im(queue_cpp, nullptr); routine.DoCol2im(kernel_mode, channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, Buffer(col_buffer), col_offset, Buffer(im_buffer), im_offset); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API Col2im(const KernelMode, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Col2im(const KernelMode, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Col2im(const KernelMode, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Col2im(const KernelMode, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Col2im(const KernelMode, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM template StatusCode Convgemm(const KernelMode kernel_mode, const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const 
CUdeviceptr im_buffer, const size_t im_offset, const CUdeviceptr kernel_buffer, const size_t kernel_offset, CUdeviceptr result_buffer, const size_t result_offset, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xconvgemm(queue_cpp, nullptr); routine.DoConvgemm(kernel_mode, channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, Buffer(im_buffer), im_offset, Buffer(kernel_buffer), kernel_offset, Buffer(result_buffer), result_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } template StatusCode PUBLIC_API Convgemm(const KernelMode, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Convgemm(const KernelMode, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API Convgemm(const KernelMode, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); // Batched version of AXPY: SAXPYBATCHED/DAXPYBATCHED/CAXPYBATCHED/ZAXPYBATCHED/HAXPYBATCHED template StatusCode AxpyBatched(const size_t n, const T 
*alphas, const CUdeviceptr x_buffer, const size_t *x_offsets, const size_t x_inc, CUdeviceptr y_buffer, const size_t *y_offsets, const size_t y_inc, const size_t batch_count, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = XaxpyBatched(queue_cpp, nullptr); auto alphas_cpp = std::vector(); auto x_offsets_cpp = std::vector(); auto y_offsets_cpp = std::vector(); for (auto batch = size_t{0}; batch < batch_count; ++batch) { alphas_cpp.push_back(alphas[batch]); x_offsets_cpp.push_back(x_offsets[batch]); y_offsets_cpp.push_back(y_offsets[batch]); } routine.DoAxpyBatched(n, alphas_cpp, Buffer(x_buffer), x_offsets_cpp, x_inc, Buffer(y_buffer), y_offsets_cpp, y_inc, batch_count); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } template StatusCode PUBLIC_API AxpyBatched(const size_t, const float*, const CUdeviceptr, const size_t*, const size_t, CUdeviceptr, const size_t*, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API AxpyBatched(const size_t, const double*, const CUdeviceptr, const size_t*, const size_t, CUdeviceptr, const size_t*, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API AxpyBatched(const size_t, const float2*, const CUdeviceptr, const size_t*, const size_t, CUdeviceptr, const size_t*, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API AxpyBatched(const size_t, const double2*, const CUdeviceptr, const size_t*, const size_t, CUdeviceptr, const size_t*, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API AxpyBatched(const size_t, const half*, const CUdeviceptr, const size_t*, const size_t, CUdeviceptr, const size_t*, const size_t, const size_t, const CUcontext, const CUdevice); // Batched version of GEMM: 
SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED template StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, const size_t m, const size_t n, const size_t k, const T *alphas, const CUdeviceptr a_buffer, const size_t *a_offsets, const size_t a_ld, const CUdeviceptr b_buffer, const size_t *b_offsets, const size_t b_ld, const T *betas, CUdeviceptr c_buffer, const size_t *c_offsets, const size_t c_ld, const size_t batch_count, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = XgemmBatched(queue_cpp, nullptr); auto alphas_cpp = std::vector(); auto betas_cpp = std::vector(); auto a_offsets_cpp = std::vector(); auto b_offsets_cpp = std::vector(); auto c_offsets_cpp = std::vector(); for (auto batch = size_t{0}; batch < batch_count; ++batch) { alphas_cpp.push_back(alphas[batch]); betas_cpp.push_back(betas[batch]); a_offsets_cpp.push_back(a_offsets[batch]); b_offsets_cpp.push_back(b_offsets[batch]); c_offsets_cpp.push_back(c_offsets[batch]); } routine.DoGemmBatched(layout, a_transpose, b_transpose, m, n, k, alphas_cpp, Buffer(a_buffer), a_offsets_cpp, a_ld, Buffer(b_buffer), b_offsets_cpp, b_ld, betas_cpp, Buffer(c_buffer), c_offsets_cpp, c_ld, batch_count); return StatusCode::kSuccess; } catch (...) 
{ return DispatchException(); } } template StatusCode PUBLIC_API GemmBatched(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const float*, const CUdeviceptr, const size_t*, const size_t, const CUdeviceptr, const size_t*, const size_t, const float*, CUdeviceptr, const size_t*, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API GemmBatched(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const double*, const CUdeviceptr, const size_t*, const size_t, const CUdeviceptr, const size_t*, const size_t, const double*, CUdeviceptr, const size_t*, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API GemmBatched(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const float2*, const CUdeviceptr, const size_t*, const size_t, const CUdeviceptr, const size_t*, const size_t, const float2*, CUdeviceptr, const size_t*, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API GemmBatched(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const double2*, const CUdeviceptr, const size_t*, const size_t, const CUdeviceptr, const size_t*, const size_t, const double2*, CUdeviceptr, const size_t*, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API GemmBatched(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const half*, const CUdeviceptr, const size_t*, const size_t, const CUdeviceptr, const size_t*, const size_t, const half*, CUdeviceptr, const size_t*, const size_t, const size_t, const CUcontext, const CUdevice); // StridedBatched version of GEMM: SGEMMSTRIDEDBATCHED/DGEMMSTRIDEDBATCHED/CGEMMSTRIDEDBATCHED/ZGEMMSTRIDEDBATCHED/HGEMMSTRIDEDBATCHED template StatusCode GemmStridedBatched(const Layout layout, const Transpose a_transpose, 
const Transpose b_transpose, const size_t m, const size_t n, const size_t k, const T alpha, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const T beta, CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, const size_t batch_count, const CUcontext context, const CUdevice device) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = XgemmStridedBatched(queue_cpp, nullptr); routine.DoGemmStridedBatched(layout, a_transpose, b_transpose, m, n, k, alpha, Buffer(a_buffer), a_offset, a_ld, a_stride, Buffer(b_buffer), b_offset, b_ld, b_stride, beta, Buffer(c_buffer), c_offset, c_ld, c_stride, batch_count); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } template StatusCode PUBLIC_API GemmStridedBatched(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const float, const CUdeviceptr, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const size_t, const float, CUdeviceptr, const size_t, const size_t, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API GemmStridedBatched(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const double, const CUdeviceptr, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const size_t, const double, CUdeviceptr, const size_t, const size_t, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API GemmStridedBatched(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const float2, const CUdeviceptr, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const 
size_t, const size_t, const float2, CUdeviceptr, const size_t, const size_t, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API GemmStridedBatched(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const double2, const CUdeviceptr, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const size_t, const double2, CUdeviceptr, const size_t, const size_t, const size_t, const size_t, const CUcontext, const CUdevice); template StatusCode PUBLIC_API GemmStridedBatched(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const half, const CUdeviceptr, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const size_t, const size_t, const half, CUdeviceptr, const size_t, const size_t, const size_t, const size_t, const CUcontext, const CUdevice); // ================================================================================================= // Retrieves the required size of the temporary buffer for the GEMM kernel (optional) template StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, const size_t m, const size_t n, const size_t k, const size_t a_offset, const size_t a_ld, const size_t b_offset, const size_t b_ld, const size_t c_offset, const size_t c_ld, const CUdevice device, size_t& temp_buffer_size) { try { // Retrieves the tuning database const auto device_cpp = Device(device); const auto kernel_names = std::vector{"Xgemm", "GemmRoutine"}; Databases db(kernel_names); Routine::InitDatabase(device_cpp, kernel_names, PrecisionValue(), {}, db); // Computes the buffer size if (Xgemm::UseDirectKernel(m, n, k, db["XGEMM_MIN_INDIRECT_SIZE"])) { temp_buffer_size = 0; } else { temp_buffer_size = Xgemm::GetTempSize(layout, a_transpose, b_transpose, m, n, k, a_offset, a_ld, b_offset, b_ld, c_offset, c_ld, db["MWG"], db["NWG"], db["KWG"] * 
db["KREG"], db["GEMMK"]); } temp_buffer_size *= sizeof(T); // translate from num-elements to bytes return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdevice, size_t&); template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdevice, size_t&); template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdevice, size_t&); template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdevice, size_t&); template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdevice, size_t&); // ================================================================================================= } // namespace clblast