summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-10-12 12:20:43 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2017-10-12 12:20:43 +0200
commitcc5b4754250b3c03b9b0f8d72f32d1eacac15b18 (patch)
tree747a5ad136f708de3559c061243e5f31bc17977a /src
parentb901809345848b44442c787380b13db5e5156df0 (diff)
CUDA API now takes context and device in instead of stream
Diffstat (limited to 'src')
-rw-r--r--src/clblast_cuda.cpp720
-rw-r--r--src/utilities/buffer_test.hpp2
2 files changed, 411 insertions, 311 deletions
diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp
index 5f30d023..f9a24236 100644
--- a/src/clblast_cuda.cpp
+++ b/src/clblast_cuda.cpp
@@ -30,19 +30,19 @@ StatusCode Rotg(CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
- CUstream* stream) {
+ const CUcontext, const CUdevice) {
return StatusCode::kNotImplemented;
}
template StatusCode PUBLIC_API Rotg<float>(CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Rotg<double>(CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Generate modified givens plane rotation: SROTMG/DROTMG
template <typename T>
@@ -51,7 +51,7 @@ StatusCode Rotmg(CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
- CUstream* stream) {
+ const CUcontext, const CUdevice) {
return StatusCode::kNotImplemented;
}
template StatusCode PUBLIC_API Rotmg<float>(CUdeviceptr, const size_t,
@@ -59,13 +59,13 @@ template StatusCode PUBLIC_API Rotmg<float>(CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Rotmg<double>(CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Apply givens plane rotation: SROT/DROT
template <typename T>
@@ -74,7 +74,7 @@ StatusCode Rot(const size_t,
CUdeviceptr, const size_t, const size_t,
const T,
const T,
- CUstream* stream) {
+ const CUcontext, const CUdevice) {
return StatusCode::kNotImplemented;
}
template StatusCode PUBLIC_API Rot<float>(const size_t,
@@ -82,13 +82,13 @@ template StatusCode PUBLIC_API Rot<float>(const size_t,
CUdeviceptr, const size_t, const size_t,
const float,
const float,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Rot<double>(const size_t,
CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
const double,
const double,
- CUstream*);
+ const CUcontext, const CUdevice);
// Apply modified givens plane rotation: SROTM/DROTM
template <typename T>
@@ -96,28 +96,30 @@ StatusCode Rotm(const size_t,
CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t,
- CUstream* stream) {
+ const CUcontext, const CUdevice) {
return StatusCode::kNotImplemented;
}
template StatusCode PUBLIC_API Rotm<float>(const size_t,
CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Rotm<double>(const size_t,
CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Swap two vectors: SSWAP/DSWAP/CSWAP/ZSWAP/HSWAP
template <typename T>
StatusCode Swap(const size_t n,
CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xswap<T>(queue_cpp, event);
routine.DoSwap(n,
Buffer<T>(x_buffer), x_offset, x_inc,
@@ -128,32 +130,34 @@ StatusCode Swap(const size_t n,
template StatusCode PUBLIC_API Swap<float>(const size_t,
CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Swap<double>(const size_t,
CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Swap<float2>(const size_t,
CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Swap<double2>(const size_t,
CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Swap<half>(const size_t,
CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Vector scaling: SSCAL/DSCAL/CSCAL/ZSCAL/HSCAL
template <typename T>
StatusCode Scal(const size_t n,
const T alpha,
CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xscal<T>(queue_cpp, event);
routine.DoScal(n,
alpha,
@@ -164,32 +168,34 @@ StatusCode Scal(const size_t n,
template StatusCode PUBLIC_API Scal<float>(const size_t,
const float,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Scal<double>(const size_t,
const double,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Scal<float2>(const size_t,
const float2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Scal<double2>(const size_t,
const double2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Scal<half>(const size_t,
const half,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Vector copy: SCOPY/DCOPY/CCOPY/ZCOPY/HCOPY
template <typename T>
StatusCode Copy(const size_t n,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xcopy<T>(queue_cpp, event);
routine.DoCopy(n,
Buffer<T>(x_buffer), x_offset, x_inc,
@@ -200,23 +206,23 @@ StatusCode Copy(const size_t n,
template StatusCode PUBLIC_API Copy<float>(const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Copy<double>(const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Copy<float2>(const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Copy<double2>(const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Copy<half>(const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY/HAXPY
template <typename T>
@@ -224,9 +230,11 @@ StatusCode Axpy(const size_t n,
const T alpha,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xaxpy<T>(queue_cpp, event);
routine.DoAxpy(n,
alpha,
@@ -239,27 +247,27 @@ template StatusCode PUBLIC_API Axpy<float>(const size_t,
const float,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Axpy<double>(const size_t,
const double,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Axpy<float2>(const size_t,
const float2,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Axpy<double2>(const size_t,
const double2,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Axpy<half>(const size_t,
const half,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Dot product of two vectors: SDOT/DDOT/HDOT
template <typename T>
@@ -267,9 +275,11 @@ StatusCode Dot(const size_t n,
CUdeviceptr dot_buffer, const size_t dot_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xdot<T>(queue_cpp, event);
routine.DoDot(n,
Buffer<T>(dot_buffer), dot_offset,
@@ -282,17 +292,17 @@ template StatusCode PUBLIC_API Dot<float>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Dot<double>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Dot<half>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Dot product of two complex vectors: CDOTU/ZDOTU
template <typename T>
@@ -300,9 +310,11 @@ StatusCode Dotu(const size_t n,
CUdeviceptr dot_buffer, const size_t dot_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xdotu<T>(queue_cpp, event);
routine.DoDotu(n,
Buffer<T>(dot_buffer), dot_offset,
@@ -315,12 +327,12 @@ template StatusCode PUBLIC_API Dotu<float2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Dotu<double2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Dot product of two complex vectors, one conjugated: CDOTC/ZDOTC
template <typename T>
@@ -328,9 +340,11 @@ StatusCode Dotc(const size_t n,
CUdeviceptr dot_buffer, const size_t dot_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xdotc<T>(queue_cpp, event);
routine.DoDotc(n,
Buffer<T>(dot_buffer), dot_offset,
@@ -343,21 +357,23 @@ template StatusCode PUBLIC_API Dotc<float2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Dotc<double2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Euclidian norm of a vector: SNRM2/DNRM2/ScNRM2/DzNRM2/HNRM2
template <typename T>
StatusCode Nrm2(const size_t n,
CUdeviceptr nrm2_buffer, const size_t nrm2_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xnrm2<T>(queue_cpp, event);
routine.DoNrm2(n,
Buffer<T>(nrm2_buffer), nrm2_offset,
@@ -368,32 +384,34 @@ StatusCode Nrm2(const size_t n,
template StatusCode PUBLIC_API Nrm2<float>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Nrm2<double>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Nrm2<float2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Nrm2<double2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Nrm2<half>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Absolute sum of values in a vector: SASUM/DASUM/ScASUM/DzASUM/HASUM
template <typename T>
StatusCode Asum(const size_t n,
CUdeviceptr asum_buffer, const size_t asum_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xasum<T>(queue_cpp, event);
routine.DoAsum(n,
Buffer<T>(asum_buffer), asum_offset,
@@ -404,32 +422,34 @@ StatusCode Asum(const size_t n,
template StatusCode PUBLIC_API Asum<float>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Asum<double>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Asum<float2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Asum<double2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Asum<half>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Sum of values in a vector (non-BLAS function): SSUM/DSUM/ScSUM/DzSUM/HSUM
template <typename T>
StatusCode Sum(const size_t n,
CUdeviceptr sum_buffer, const size_t sum_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xsum<T>(queue_cpp, event);
routine.DoSum(n,
Buffer<T>(sum_buffer), sum_offset,
@@ -440,32 +460,34 @@ StatusCode Sum(const size_t n,
template StatusCode PUBLIC_API Sum<float>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Sum<double>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Sum<float2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Sum<double2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Sum<half>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Index of absolute maximum value in a vector: iSAMAX/iDAMAX/iCAMAX/iZAMAX/iHAMAX
template <typename T>
StatusCode Amax(const size_t n,
CUdeviceptr imax_buffer, const size_t imax_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xamax<T>(queue_cpp, event);
routine.DoAmax(n,
Buffer<unsigned int>(imax_buffer), imax_offset,
@@ -476,32 +498,34 @@ StatusCode Amax(const size_t n,
template StatusCode PUBLIC_API Amax<float>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Amax<double>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Amax<float2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Amax<double2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Amax<half>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Index of absolute minimum value in a vector (non-BLAS function): iSAMIN/iDAMIN/iCAMIN/iZAMIN/iHAMIN
template <typename T>
StatusCode Amin(const size_t n,
CUdeviceptr imin_buffer, const size_t imin_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xamin<T>(queue_cpp, event);
routine.DoAmin(n,
Buffer<unsigned int>(imin_buffer), imin_offset,
@@ -512,32 +536,34 @@ StatusCode Amin(const size_t n,
template StatusCode PUBLIC_API Amin<float>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Amin<double>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Amin<float2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Amin<double2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Amin<half>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Index of maximum value in a vector (non-BLAS function): iSMAX/iDMAX/iCMAX/iZMAX/iHMAX
template <typename T>
StatusCode Max(const size_t n,
CUdeviceptr imax_buffer, const size_t imax_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xmax<T>(queue_cpp, event);
routine.DoMax(n,
Buffer<unsigned int>(imax_buffer), imax_offset,
@@ -548,32 +574,34 @@ StatusCode Max(const size_t n,
template StatusCode PUBLIC_API Max<float>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Max<double>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Max<float2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Max<double2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Max<half>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Index of minimum value in a vector (non-BLAS function): iSMIN/iDMIN/iCMIN/iZMIN/iHMIN
template <typename T>
StatusCode Min(const size_t n,
CUdeviceptr imin_buffer, const size_t imin_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xmin<T>(queue_cpp, event);
routine.DoMin(n,
Buffer<unsigned int>(imin_buffer), imin_offset,
@@ -584,23 +612,23 @@ StatusCode Min(const size_t n,
template StatusCode PUBLIC_API Min<float>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Min<double>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Min<float2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Min<double2>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Min<half>(const size_t,
CUdeviceptr, const size_t,
const CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// =================================================================================================
// BLAS level-2 (matrix-vector) routines
@@ -615,9 +643,11 @@ StatusCode Gemv(const Layout layout, const Transpose a_transpose,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const T beta,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xgemv<T>(queue_cpp, event);
routine.DoGemv(layout, a_transpose,
m, n,
@@ -636,7 +666,7 @@ template StatusCode PUBLIC_API Gemv<float>(const Layout, const Transpose,
const CUdeviceptr, const size_t, const size_t,
const float,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gemv<double>(const Layout, const Transpose,
const size_t, const size_t,
const double,
@@ -644,7 +674,7 @@ template StatusCode PUBLIC_API Gemv<double>(const Layout, const Transpose,
const CUdeviceptr, const size_t, const size_t,
const double,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gemv<float2>(const Layout, const Transpose,
const size_t, const size_t,
const float2,
@@ -652,7 +682,7 @@ template StatusCode PUBLIC_API Gemv<float2>(const Layout, const Transpose,
const CUdeviceptr, const size_t, const size_t,
const float2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gemv<double2>(const Layout, const Transpose,
const size_t, const size_t,
const double2,
@@ -660,7 +690,7 @@ template StatusCode PUBLIC_API Gemv<double2>(const Layout, const Transpose,
const CUdeviceptr, const size_t, const size_t,
const double2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gemv<half>(const Layout, const Transpose,
const size_t, const size_t,
const half,
@@ -668,7 +698,7 @@ template StatusCode PUBLIC_API Gemv<half>(const Layout, const Transpose,
const CUdeviceptr, const size_t, const size_t,
const half,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// General banded matrix-vector multiplication: SGBMV/DGBMV/CGBMV/ZGBMV/HGBMV
template <typename T>
@@ -679,9 +709,11 @@ StatusCode Gbmv(const Layout layout, const Transpose a_transpose,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const T beta,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xgbmv<T>(queue_cpp, event);
routine.DoGbmv(layout, a_transpose,
m, n, kl, ku,
@@ -700,7 +732,7 @@ template StatusCode PUBLIC_API Gbmv<float>(const Layout, const Transpose,
const CUdeviceptr, const size_t, const size_t,
const float,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gbmv<double>(const Layout, const Transpose,
const size_t, const size_t, const size_t, const size_t,
const double,
@@ -708,7 +740,7 @@ template StatusCode PUBLIC_API Gbmv<double>(const Layout, const Transpose,
const CUdeviceptr, const size_t, const size_t,
const double,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gbmv<float2>(const Layout, const Transpose,
const size_t, const size_t, const size_t, const size_t,
const float2,
@@ -716,7 +748,7 @@ template StatusCode PUBLIC_API Gbmv<float2>(const Layout, const Transpose,
const CUdeviceptr, const size_t, const size_t,
const float2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gbmv<double2>(const Layout, const Transpose,
const size_t, const size_t, const size_t, const size_t,
const double2,
@@ -724,7 +756,7 @@ template StatusCode PUBLIC_API Gbmv<double2>(const Layout, const Transpose,
const CUdeviceptr, const size_t, const size_t,
const double2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gbmv<half>(const Layout, const Transpose,
const size_t, const size_t, const size_t, const size_t,
const half,
@@ -732,7 +764,7 @@ template StatusCode PUBLIC_API Gbmv<half>(const Layout, const Transpose,
const CUdeviceptr, const size_t, const size_t,
const half,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Hermitian matrix-vector multiplication: CHEMV/ZHEMV
template <typename T>
@@ -743,9 +775,11 @@ StatusCode Hemv(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const T beta,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xhemv<T>(queue_cpp, event);
routine.DoHemv(layout, triangle,
n,
@@ -764,7 +798,7 @@ template StatusCode PUBLIC_API Hemv<float2>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const float2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Hemv<double2>(const Layout, const Triangle,
const size_t,
const double2,
@@ -772,7 +806,7 @@ template StatusCode PUBLIC_API Hemv<double2>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const double2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Hermitian banded matrix-vector multiplication: CHBMV/ZHBMV
template <typename T>
@@ -783,9 +817,11 @@ StatusCode Hbmv(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const T beta,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xhbmv<T>(queue_cpp, event);
routine.DoHbmv(layout, triangle,
n, k,
@@ -804,7 +840,7 @@ template StatusCode PUBLIC_API Hbmv<float2>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const float2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Hbmv<double2>(const Layout, const Triangle,
const size_t, const size_t,
const double2,
@@ -812,7 +848,7 @@ template StatusCode PUBLIC_API Hbmv<double2>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const double2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Hermitian packed matrix-vector multiplication: CHPMV/ZHPMV
template <typename T>
@@ -823,9 +859,11 @@ StatusCode Hpmv(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const T beta,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xhpmv<T>(queue_cpp, event);
routine.DoHpmv(layout, triangle,
n,
@@ -844,7 +882,7 @@ template StatusCode PUBLIC_API Hpmv<float2>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const float2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Hpmv<double2>(const Layout, const Triangle,
const size_t,
const double2,
@@ -852,7 +890,7 @@ template StatusCode PUBLIC_API Hpmv<double2>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const double2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Symmetric matrix-vector multiplication: SSYMV/DSYMV/HSYMV
template <typename T>
@@ -863,9 +901,11 @@ StatusCode Symv(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const T beta,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xsymv<T>(queue_cpp, event);
routine.DoSymv(layout, triangle,
n,
@@ -884,7 +924,7 @@ template StatusCode PUBLIC_API Symv<float>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const float,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Symv<double>(const Layout, const Triangle,
const size_t,
const double,
@@ -892,7 +932,7 @@ template StatusCode PUBLIC_API Symv<double>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const double,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Symv<half>(const Layout, const Triangle,
const size_t,
const half,
@@ -900,7 +940,7 @@ template StatusCode PUBLIC_API Symv<half>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const half,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Symmetric banded matrix-vector multiplication: SSBMV/DSBMV/HSBMV
template <typename T>
@@ -911,9 +951,11 @@ StatusCode Sbmv(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const T beta,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xsbmv<T>(queue_cpp, event);
routine.DoSbmv(layout, triangle,
n, k,
@@ -932,7 +974,7 @@ template StatusCode PUBLIC_API Sbmv<float>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const float,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Sbmv<double>(const Layout, const Triangle,
const size_t, const size_t,
const double,
@@ -940,7 +982,7 @@ template StatusCode PUBLIC_API Sbmv<double>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const double,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Sbmv<half>(const Layout, const Triangle,
const size_t, const size_t,
const half,
@@ -948,7 +990,7 @@ template StatusCode PUBLIC_API Sbmv<half>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const half,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Symmetric packed matrix-vector multiplication: SSPMV/DSPMV/HSPMV
template <typename T>
@@ -959,9 +1001,11 @@ StatusCode Spmv(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const T beta,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xspmv<T>(queue_cpp, event);
routine.DoSpmv(layout, triangle,
n,
@@ -980,7 +1024,7 @@ template StatusCode PUBLIC_API Spmv<float>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const float,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Spmv<double>(const Layout, const Triangle,
const size_t,
const double,
@@ -988,7 +1032,7 @@ template StatusCode PUBLIC_API Spmv<double>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const double,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Spmv<half>(const Layout, const Triangle,
const size_t,
const half,
@@ -996,7 +1040,7 @@ template StatusCode PUBLIC_API Spmv<half>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const half,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Triangular matrix-vector multiplication: STRMV/DTRMV/CTRMV/ZTRMV/HTRMV
template <typename T>
@@ -1004,9 +1048,11 @@ StatusCode Trmv(const Layout layout, const Triangle triangle, const Transpose a_
const size_t n,
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xtrmv<T>(queue_cpp, event);
routine.DoTrmv(layout, triangle, a_transpose, diagonal,
n,
@@ -1019,27 +1065,27 @@ template StatusCode PUBLIC_API Trmv<float>(const Layout, const Triangle, const T
const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Trmv<double>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Trmv<float2>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Trmv<double2>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Trmv<half>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Triangular banded matrix-vector multiplication: STBMV/DTBMV/CTBMV/ZTBMV/HTBMV
template <typename T>
@@ -1047,9 +1093,11 @@ StatusCode Tbmv(const Layout layout, const Triangle triangle, const Transpose a_
const size_t n, const size_t k,
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xtbmv<T>(queue_cpp, event);
routine.DoTbmv(layout, triangle, a_transpose, diagonal,
n, k,
@@ -1062,27 +1110,27 @@ template StatusCode PUBLIC_API Tbmv<float>(const Layout, const Triangle, const T
const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Tbmv<double>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Tbmv<float2>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Tbmv<double2>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Tbmv<half>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Triangular packed matrix-vector multiplication: STPMV/DTPMV/CTPMV/ZTPMV/HTPMV
template <typename T>
@@ -1090,9 +1138,11 @@ StatusCode Tpmv(const Layout layout, const Triangle triangle, const Transpose a_
const size_t n,
const CUdeviceptr ap_buffer, const size_t ap_offset,
CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xtpmv<T>(queue_cpp, event);
routine.DoTpmv(layout, triangle, a_transpose, diagonal,
n,
@@ -1105,27 +1155,27 @@ template StatusCode PUBLIC_API Tpmv<float>(const Layout, const Triangle, const T
const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Tpmv<double>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Tpmv<float2>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Tpmv<double2>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Tpmv<half>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Solves a triangular system of equations: STRSV/DTRSV/CTRSV/ZTRSV
template <typename T>
@@ -1133,9 +1183,11 @@ StatusCode Trsv(const Layout layout, const Triangle triangle, const Transpose a_
const size_t n,
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xtrsv<T>(queue_cpp, event);
routine.DoTrsv(layout, triangle, a_transpose, diagonal,
n,
@@ -1148,22 +1200,22 @@ template StatusCode PUBLIC_API Trsv<float>(const Layout, const Triangle, const T
const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Trsv<double>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Trsv<float2>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Trsv<double2>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Solves a banded triangular system of equations: STBSV/DTBSV/CTBSV/ZTBSV
template <typename T>
@@ -1171,29 +1223,29 @@ StatusCode Tbsv(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream* stream) {
+ const CUcontext, const CUdevice) {
return StatusCode::kNotImplemented;
}
template StatusCode PUBLIC_API Tbsv<float>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Tbsv<double>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Tbsv<float2>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Tbsv<double2>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Solves a packed triangular system of equations: STPSV/DTPSV/CTPSV/ZTPSV
template <typename T>
@@ -1201,29 +1253,29 @@ StatusCode Tpsv(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream* stream) {
+ const CUcontext, const CUdevice) {
return StatusCode::kNotImplemented;
}
template StatusCode PUBLIC_API Tpsv<float>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Tpsv<double>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Tpsv<float2>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Tpsv<double2>(const Layout, const Triangle, const Transpose, const Diagonal,
const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// General rank-1 matrix update: SGER/DGER/HGER
template <typename T>
@@ -1233,9 +1285,11 @@ StatusCode Ger(const Layout layout,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xger<T>(queue_cpp, event);
routine.DoGer(layout,
m, n,
@@ -1252,21 +1306,21 @@ template StatusCode PUBLIC_API Ger<float>(const Layout,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Ger<double>(const Layout,
const size_t, const size_t,
const double,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Ger<half>(const Layout,
const size_t, const size_t,
const half,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// General rank-1 complex matrix update: CGERU/ZGERU
template <typename T>
@@ -1276,9 +1330,11 @@ StatusCode Geru(const Layout layout,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xgeru<T>(queue_cpp, event);
routine.DoGeru(layout,
m, n,
@@ -1295,14 +1351,14 @@ template StatusCode PUBLIC_API Geru<float2>(const Layout,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Geru<double2>(const Layout,
const size_t, const size_t,
const double2,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// General rank-1 complex conjugated matrix update: CGERC/ZGERC
template <typename T>
@@ -1312,9 +1368,11 @@ StatusCode Gerc(const Layout layout,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xgerc<T>(queue_cpp, event);
routine.DoGerc(layout,
m, n,
@@ -1331,14 +1389,14 @@ template StatusCode PUBLIC_API Gerc<float2>(const Layout,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gerc<double2>(const Layout,
const size_t, const size_t,
const double2,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Hermitian rank-1 matrix update: CHER/ZHER
template <typename T>
@@ -1347,9 +1405,11 @@ StatusCode Her(const Layout layout, const Triangle triangle,
const T alpha,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xher<std::complex<T>,T>(queue_cpp, event);
routine.DoHer(layout, triangle,
n,
@@ -1364,13 +1424,13 @@ template StatusCode PUBLIC_API Her<float>(const Layout, const Triangle,
const float,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Her<double>(const Layout, const Triangle,
const size_t,
const double,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Hermitian packed rank-1 matrix update: CHPR/ZHPR
template <typename T>
@@ -1379,9 +1439,11 @@ StatusCode Hpr(const Layout layout, const Triangle triangle,
const T alpha,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUdeviceptr ap_buffer, const size_t ap_offset,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xhpr<std::complex<T>,T>(queue_cpp, event);
routine.DoHpr(layout, triangle,
n,
@@ -1396,13 +1458,13 @@ template StatusCode PUBLIC_API Hpr<float>(const Layout, const Triangle,
const float,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Hpr<double>(const Layout, const Triangle,
const size_t,
const double,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Hermitian rank-2 matrix update: CHER2/ZHER2
template <typename T>
@@ -1412,9 +1474,11 @@ StatusCode Her2(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xher2<T>(queue_cpp, event);
routine.DoHer2(layout, triangle,
n,
@@ -1431,14 +1495,14 @@ template StatusCode PUBLIC_API Her2<float2>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Her2<double2>(const Layout, const Triangle,
const size_t,
const double2,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Hermitian packed rank-2 matrix update: CHPR2/ZHPR2
template <typename T>
@@ -1448,9 +1512,11 @@ StatusCode Hpr2(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUdeviceptr ap_buffer, const size_t ap_offset,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xhpr2<T>(queue_cpp, event);
routine.DoHpr2(layout, triangle,
n,
@@ -1467,14 +1533,14 @@ template StatusCode PUBLIC_API Hpr2<float2>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Hpr2<double2>(const Layout, const Triangle,
const size_t,
const double2,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Symmetric rank-1 matrix update: SSYR/DSYR/HSYR
template <typename T>
@@ -1483,9 +1549,11 @@ StatusCode Syr(const Layout layout, const Triangle triangle,
const T alpha,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xsyr<T>(queue_cpp, event);
routine.DoSyr(layout, triangle,
n,
@@ -1500,19 +1568,19 @@ template StatusCode PUBLIC_API Syr<float>(const Layout, const Triangle,
const float,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Syr<double>(const Layout, const Triangle,
const size_t,
const double,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Syr<half>(const Layout, const Triangle,
const size_t,
const half,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Symmetric packed rank-1 matrix update: SSPR/DSPR/HSPR
template <typename T>
@@ -1521,9 +1589,11 @@ StatusCode Spr(const Layout layout, const Triangle triangle,
const T alpha,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUdeviceptr ap_buffer, const size_t ap_offset,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xspr<T>(queue_cpp, event);
routine.DoSpr(layout, triangle,
n,
@@ -1538,19 +1608,19 @@ template StatusCode PUBLIC_API Spr<float>(const Layout, const Triangle,
const float,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Spr<double>(const Layout, const Triangle,
const size_t,
const double,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Spr<half>(const Layout, const Triangle,
const size_t,
const half,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Symmetric rank-2 matrix update: SSYR2/DSYR2/HSYR2
template <typename T>
@@ -1560,9 +1630,11 @@ StatusCode Syr2(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xsyr2<T>(queue_cpp, event);
routine.DoSyr2(layout, triangle,
n,
@@ -1579,21 +1651,21 @@ template StatusCode PUBLIC_API Syr2<float>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Syr2<double>(const Layout, const Triangle,
const size_t,
const double,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Syr2<half>(const Layout, const Triangle,
const size_t,
const half,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Symmetric packed rank-2 matrix update: SSPR2/DSPR2/HSPR2
template <typename T>
@@ -1603,9 +1675,11 @@ StatusCode Spr2(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUdeviceptr ap_buffer, const size_t ap_offset,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xspr2<T>(queue_cpp, event);
routine.DoSpr2(layout, triangle,
n,
@@ -1622,21 +1696,21 @@ template StatusCode PUBLIC_API Spr2<float>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Spr2<double>(const Layout, const Triangle,
const size_t,
const double,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Spr2<half>(const Layout, const Triangle,
const size_t,
const half,
const CUdeviceptr, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// =================================================================================================
// BLAS level-3 (matrix-matrix) routines
@@ -1651,9 +1725,11 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xgemm<T>(queue_cpp, event);
routine.DoGemm(layout, a_transpose, b_transpose,
m, n, k,
@@ -1672,7 +1748,7 @@ template StatusCode PUBLIC_API Gemm<float>(const Layout, const Transpose, const
const CUdeviceptr, const size_t, const size_t,
const float,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double,
@@ -1680,7 +1756,7 @@ template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const
const CUdeviceptr, const size_t, const size_t,
const double,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const float2,
@@ -1688,7 +1764,7 @@ template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const
const CUdeviceptr, const size_t, const size_t,
const float2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double2,
@@ -1696,7 +1772,7 @@ template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, cons
const CUdeviceptr, const size_t, const size_t,
const double2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const half,
@@ -1704,7 +1780,7 @@ template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const T
const CUdeviceptr, const size_t, const size_t,
const half,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
template <typename T>
@@ -1715,9 +1791,11 @@ StatusCode Symm(const Layout layout, const Side side, const Triangle triangle,
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xsymm<T>(queue_cpp, event);
routine.DoSymm(layout, side, triangle,
m, n,
@@ -1736,7 +1814,7 @@ template StatusCode PUBLIC_API Symm<float>(const Layout, const Side, const Trian
const CUdeviceptr, const size_t, const size_t,
const float,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Symm<double>(const Layout, const Side, const Triangle,
const size_t, const size_t,
const double,
@@ -1744,7 +1822,7 @@ template StatusCode PUBLIC_API Symm<double>(const Layout, const Side, const Tria
const CUdeviceptr, const size_t, const size_t,
const double,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Symm<float2>(const Layout, const Side, const Triangle,
const size_t, const size_t,
const float2,
@@ -1752,7 +1830,7 @@ template StatusCode PUBLIC_API Symm<float2>(const Layout, const Side, const Tria
const CUdeviceptr, const size_t, const size_t,
const float2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Symm<double2>(const Layout, const Side, const Triangle,
const size_t, const size_t,
const double2,
@@ -1760,7 +1838,7 @@ template StatusCode PUBLIC_API Symm<double2>(const Layout, const Side, const Tri
const CUdeviceptr, const size_t, const size_t,
const double2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Symm<half>(const Layout, const Side, const Triangle,
const size_t, const size_t,
const half,
@@ -1768,7 +1846,7 @@ template StatusCode PUBLIC_API Symm<half>(const Layout, const Side, const Triang
const CUdeviceptr, const size_t, const size_t,
const half,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Hermitian matrix-matrix multiplication: CHEMM/ZHEMM
template <typename T>
@@ -1779,9 +1857,11 @@ StatusCode Hemm(const Layout layout, const Side side, const Triangle triangle,
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xhemm<T>(queue_cpp, event);
routine.DoHemm(layout, side, triangle,
m, n,
@@ -1800,7 +1880,7 @@ template StatusCode PUBLIC_API Hemm<float2>(const Layout, const Side, const Tria
const CUdeviceptr, const size_t, const size_t,
const float2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Hemm<double2>(const Layout, const Side, const Triangle,
const size_t, const size_t,
const double2,
@@ -1808,7 +1888,7 @@ template StatusCode PUBLIC_API Hemm<double2>(const Layout, const Side, const Tri
const CUdeviceptr, const size_t, const size_t,
const double2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK/HSYRK
template <typename T>
@@ -1818,9 +1898,11 @@ StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose a_
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xsyrk<T>(queue_cpp, event);
routine.DoSyrk(layout, triangle, a_transpose,
n, k,
@@ -1837,35 +1919,35 @@ template StatusCode PUBLIC_API Syrk<float>(const Layout, const Triangle, const T
const CUdeviceptr, const size_t, const size_t,
const float,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Syrk<double>(const Layout, const Triangle, const Transpose,
const size_t, const size_t,
const double,
const CUdeviceptr, const size_t, const size_t,
const double,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Syrk<float2>(const Layout, const Triangle, const Transpose,
const size_t, const size_t,
const float2,
const CUdeviceptr, const size_t, const size_t,
const float2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Syrk<double2>(const Layout, const Triangle, const Transpose,
const size_t, const size_t,
const double2,
const CUdeviceptr, const size_t, const size_t,
const double2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Syrk<half>(const Layout, const Triangle, const Transpose,
const size_t, const size_t,
const half,
const CUdeviceptr, const size_t, const size_t,
const half,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Rank-K update of a hermitian matrix: CHERK/ZHERK
template <typename T>
@@ -1875,9 +1957,11 @@ StatusCode Herk(const Layout layout, const Triangle triangle, const Transpose a_
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xherk<std::complex<T>,T>(queue_cpp, event);
routine.DoHerk(layout, triangle, a_transpose,
n, k,
@@ -1894,14 +1978,14 @@ template StatusCode PUBLIC_API Herk<float>(const Layout, const Triangle, const T
const CUdeviceptr, const size_t, const size_t,
const float,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Herk<double>(const Layout, const Triangle, const Transpose,
const size_t, const size_t,
const double,
const CUdeviceptr, const size_t, const size_t,
const double,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K/HSYR2K
template <typename T>
@@ -1912,9 +1996,11 @@ StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose a
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xsyr2k<T>(queue_cpp, event);
routine.DoSyr2k(layout, triangle, ab_transpose,
n, k,
@@ -1933,7 +2019,7 @@ template StatusCode PUBLIC_API Syr2k<float>(const Layout, const Triangle, const
const CUdeviceptr, const size_t, const size_t,
const float,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Syr2k<double>(const Layout, const Triangle, const Transpose,
const size_t, const size_t,
const double,
@@ -1941,7 +2027,7 @@ template StatusCode PUBLIC_API Syr2k<double>(const Layout, const Triangle, const
const CUdeviceptr, const size_t, const size_t,
const double,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Syr2k<float2>(const Layout, const Triangle, const Transpose,
const size_t, const size_t,
const float2,
@@ -1949,7 +2035,7 @@ template StatusCode PUBLIC_API Syr2k<float2>(const Layout, const Triangle, const
const CUdeviceptr, const size_t, const size_t,
const float2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Syr2k<double2>(const Layout, const Triangle, const Transpose,
const size_t, const size_t,
const double2,
@@ -1957,7 +2043,7 @@ template StatusCode PUBLIC_API Syr2k<double2>(const Layout, const Triangle, cons
const CUdeviceptr, const size_t, const size_t,
const double2,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Syr2k<half>(const Layout, const Triangle, const Transpose,
const size_t, const size_t,
const half,
@@ -1965,7 +2051,7 @@ template StatusCode PUBLIC_API Syr2k<half>(const Layout, const Triangle, const T
const CUdeviceptr, const size_t, const size_t,
const half,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Rank-2K update of a hermitian matrix: CHER2K/ZHER2K
template <typename T, typename U>
@@ -1976,9 +2062,11 @@ StatusCode Her2k(const Layout layout, const Triangle triangle, const Transpose a
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
const U beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xher2k<T,U>(queue_cpp, event);
routine.DoHer2k(layout, triangle, ab_transpose,
n, k,
@@ -1997,7 +2085,7 @@ template StatusCode PUBLIC_API Her2k<float2,float>(const Layout, const Triangle,
const CUdeviceptr, const size_t, const size_t,
const float,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Her2k<double2,double>(const Layout, const Triangle, const Transpose,
const size_t, const size_t,
const double2,
@@ -2005,7 +2093,7 @@ template StatusCode PUBLIC_API Her2k<double2,double>(const Layout, const Triangl
const CUdeviceptr, const size_t, const size_t,
const double,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM/HTRMM
template <typename T>
@@ -2014,9 +2102,11 @@ StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, c
const T alpha,
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xtrmm<T>(queue_cpp, event);
routine.DoTrmm(layout, side, triangle, a_transpose, diagonal,
m, n,
@@ -2031,31 +2121,31 @@ template StatusCode PUBLIC_API Trmm<float>(const Layout, const Side, const Trian
const float,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Trmm<double>(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
const double,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Trmm<float2>(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
const float2,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Trmm<double2>(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
const double2,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Trmm<half>(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
const half,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM
template <typename T>
@@ -2064,9 +2154,11 @@ StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, c
const T alpha,
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xtrsm<T>(queue_cpp, event);
routine.DoTrsm(layout, side, triangle, a_transpose, diagonal,
m, n,
@@ -2081,25 +2173,25 @@ template StatusCode PUBLIC_API Trsm<float>(const Layout, const Side, const Trian
const float,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Trsm<double>(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
const double,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Trsm<float2>(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
const float2,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Trsm<double2>(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
const double2,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// =================================================================================================
// Extra non-BLAS routines (level-X)
@@ -2112,9 +2204,11 @@ StatusCode Omatcopy(const Layout layout, const Transpose a_transpose,
const T alpha,
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xomatcopy<T>(queue_cpp, event);
routine.DoOmatcopy(layout, a_transpose,
m, n,
@@ -2129,40 +2223,42 @@ template StatusCode PUBLIC_API Omatcopy<float>(const Layout, const Transpose,
const float,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Omatcopy<double>(const Layout, const Transpose,
const size_t, const size_t,
const double,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Omatcopy<float2>(const Layout, const Transpose,
const size_t, const size_t,
const float2,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Omatcopy<double2>(const Layout, const Transpose,
const size_t, const size_t,
const double2,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose,
const size_t, const size_t,
const half,
const CUdeviceptr, const size_t, const size_t,
CUdeviceptr, const size_t, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL
template <typename T>
StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const CUdeviceptr im_buffer, const size_t im_offset,
CUdeviceptr col_buffer, const size_t col_offset,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xim2col<T>(queue_cpp, event);
routine.DoIm2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
Buffer<T>(im_buffer), im_offset,
@@ -2173,23 +2269,23 @@ StatusCode Im2col(const size_t channels, const size_t height, const size_t width
template StatusCode PUBLIC_API Im2col<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Im2col<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Im2col<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Im2col<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Batched version of AXPY: SAXPYBATCHED/DAXPYBATCHED/CAXPYBATCHED/ZAXPYBATCHED/HAXPYBATCHED
template <typename T>
@@ -2198,9 +2294,11 @@ StatusCode AxpyBatched(const size_t n,
const CUdeviceptr x_buffer, const size_t *x_offsets, const size_t x_inc,
CUdeviceptr y_buffer, const size_t *y_offsets, const size_t y_inc,
const size_t batch_count,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = XaxpyBatched<T>(queue_cpp, event);
auto alphas_cpp = std::vector<T>();
auto x_offsets_cpp = std::vector<size_t>();
@@ -2223,31 +2321,31 @@ template StatusCode PUBLIC_API AxpyBatched<float>(const size_t,
const CUdeviceptr, const size_t*, const size_t,
CUdeviceptr, const size_t*, const size_t,
const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API AxpyBatched<double>(const size_t,
const double*,
const CUdeviceptr, const size_t*, const size_t,
CUdeviceptr, const size_t*, const size_t,
const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API AxpyBatched<float2>(const size_t,
const float2*,
const CUdeviceptr, const size_t*, const size_t,
CUdeviceptr, const size_t*, const size_t,
const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API AxpyBatched<double2>(const size_t,
const double2*,
const CUdeviceptr, const size_t*, const size_t,
CUdeviceptr, const size_t*, const size_t,
const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API AxpyBatched<half>(const size_t,
const half*,
const CUdeviceptr, const size_t*, const size_t,
CUdeviceptr, const size_t*, const size_t,
const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED
template <typename T>
@@ -2259,9 +2357,11 @@ StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const T
const T *betas,
CUdeviceptr c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
- CUstream* stream) {
+ const CUcontext context, const CUdevice device) {
try {
- auto queue_cpp = Queue(*queue);
+ const auto context_cpp = Context(context);
+ const auto device_cpp = Device(device);
+ auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = XgemmBatched<T>(queue_cpp, event);
auto alphas_cpp = std::vector<T>();
auto betas_cpp = std::vector<T>();
@@ -2294,7 +2394,7 @@ template StatusCode PUBLIC_API GemmBatched<float>(const Layout, const Transpose,
const float*,
CUdeviceptr, const size_t*, const size_t,
const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API GemmBatched<double>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double*,
@@ -2303,7 +2403,7 @@ template StatusCode PUBLIC_API GemmBatched<double>(const Layout, const Transpose
const double*,
CUdeviceptr, const size_t*, const size_t,
const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API GemmBatched<float2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const float2*,
@@ -2312,7 +2412,7 @@ template StatusCode PUBLIC_API GemmBatched<float2>(const Layout, const Transpose
const float2*,
CUdeviceptr, const size_t*, const size_t,
const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API GemmBatched<double2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double2*,
@@ -2321,7 +2421,7 @@ template StatusCode PUBLIC_API GemmBatched<double2>(const Layout, const Transpos
const double2*,
CUdeviceptr, const size_t*, const size_t,
const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const half*,
@@ -2330,7 +2430,7 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose,
const half*,
CUdeviceptr, const size_t*, const size_t,
const size_t,
- CUstream*);
+ const CUcontext, const CUdevice);
// =================================================================================================
} // namespace clblast
diff --git a/src/utilities/buffer_test.hpp b/src/utilities/buffer_test.hpp
index a5b6be4b..fd071434 100644
--- a/src/utilities/buffer_test.hpp
+++ b/src/utilities/buffer_test.hpp
@@ -15,7 +15,7 @@
#ifndef CLBLAST_BUFFER_TEST_H_
#define CLBLAST_BUFFER_TEST_H_
-#include "utilities/utilities.hpp
+#include "utilities/utilities.hpp"
namespace clblast {
// =================================================================================================