summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-01-06 16:08:27 +0100
committerGitHub <noreply@github.com>2018-01-06 16:08:27 +0100
commita7ccce196915db7a3b7ea7fe8ea9048f5b1204c6 (patch)
tree27dd8771ee6f913b5a2dabfae115bbe7fbc9d979
parent8040a4e355bdf6531eb9c4c5ae1fe4f792899d24 (diff)
parentad197da08da7ef414db90dbb97e92c575363c280 (diff)
Merge pull request #238 from CNugteren/gemm_api_with_temp_buffer
GEMM API with optional temp buffer
-rw-r--r--ROADMAP.md5
-rw-r--r--doc/clblast.md3
-rw-r--r--include/clblast.h15
-rw-r--r--include/clblast_cuda.h15
-rwxr-xr-xscripts/generator/generator.py116
-rw-r--r--scripts/generator/generator/cpp.py11
-rw-r--r--scripts/generator/generator/routine.py13
-rw-r--r--src/clblast.cpp71
-rw-r--r--src/clblast_cuda.cpp70
-rw-r--r--src/clpp11.hpp9
-rw-r--r--src/cupp11.hpp6
-rw-r--r--src/database/database.cpp38
-rw-r--r--src/database/database.hpp2
-rw-r--r--src/routine.cpp20
-rw-r--r--src/routine.hpp20
-rw-r--r--src/routines/level3/xgemm.cpp107
-rw-r--r--src/routines/level3/xgemm.hpp134
-rw-r--r--test/correctness/testblas.hpp2
-rw-r--r--test/performance/client.cpp2
-rw-r--r--test/performance/client.hpp2
-rw-r--r--test/routines/level1/xamax.hpp2
-rw-r--r--test/routines/level1/xasum.hpp2
-rw-r--r--test/routines/level1/xaxpy.hpp2
-rw-r--r--test/routines/level1/xcopy.hpp2
-rw-r--r--test/routines/level1/xdot.hpp2
-rw-r--r--test/routines/level1/xdotc.hpp2
-rw-r--r--test/routines/level1/xdotu.hpp2
-rw-r--r--test/routines/level1/xnrm2.hpp2
-rw-r--r--test/routines/level1/xscal.hpp2
-rw-r--r--test/routines/level1/xswap.hpp2
-rw-r--r--test/routines/level2/xgbmv.hpp2
-rw-r--r--test/routines/level2/xgemv.hpp2
-rw-r--r--test/routines/level2/xger.hpp2
-rw-r--r--test/routines/level2/xgerc.hpp2
-rw-r--r--test/routines/level2/xgeru.hpp2
-rw-r--r--test/routines/level2/xhbmv.hpp2
-rw-r--r--test/routines/level2/xhemv.hpp2
-rw-r--r--test/routines/level2/xher.hpp2
-rw-r--r--test/routines/level2/xher2.hpp2
-rw-r--r--test/routines/level2/xhpmv.hpp2
-rw-r--r--test/routines/level2/xhpr.hpp2
-rw-r--r--test/routines/level2/xhpr2.hpp2
-rw-r--r--test/routines/level2/xsbmv.hpp2
-rw-r--r--test/routines/level2/xspmv.hpp2
-rw-r--r--test/routines/level2/xspr.hpp2
-rw-r--r--test/routines/level2/xspr2.hpp2
-rw-r--r--test/routines/level2/xsymv.hpp2
-rw-r--r--test/routines/level2/xsyr.hpp2
-rw-r--r--test/routines/level2/xsyr2.hpp2
-rw-r--r--test/routines/level2/xtbmv.hpp2
-rw-r--r--test/routines/level2/xtpmv.hpp2
-rw-r--r--test/routines/level2/xtrmv.hpp2
-rw-r--r--test/routines/level2/xtrsv.hpp2
-rw-r--r--test/routines/level3/xgemm.hpp39
-rw-r--r--test/routines/level3/xhemm.hpp2
-rw-r--r--test/routines/level3/xher2k.hpp2
-rw-r--r--test/routines/level3/xherk.hpp2
-rw-r--r--test/routines/level3/xsymm.hpp2
-rw-r--r--test/routines/level3/xsyr2k.hpp2
-rw-r--r--test/routines/level3/xsyrk.hpp2
-rw-r--r--test/routines/level3/xtrmm.hpp2
-rw-r--r--test/routines/level3/xtrsm.hpp2
-rw-r--r--test/routines/levelx/xaxpybatched.hpp2
-rw-r--r--test/routines/levelx/xgemmbatched.hpp2
-rw-r--r--test/routines/levelx/xim2col.hpp2
-rw-r--r--test/routines/levelx/xinvert.hpp2
-rw-r--r--test/routines/levelx/xomatcopy.hpp2
67 files changed, 538 insertions, 254 deletions
diff --git a/ROADMAP.md b/ROADMAP.md
index 18ac0bc5..0488d048 100644
--- a/ROADMAP.md
+++ b/ROADMAP.md
@@ -10,8 +10,9 @@ This file gives an overview of the main features planned for addition to CLBlast
| [#181](https://github.com/CNugteren/CLBlast/issues/181) & #201 | Nov '17 | CNugteren | ✔ | Compilation for Android and testing on a device |
| - | Nov '17 | CNugteren | ✔ | Integration of CLTune for easy testing on Android / fewer dependencies |
| [#128](https://github.com/CNugteren/CLBlast/issues/128) & #205 | Nov-Dec '17 | CNugteren | ✔ | Pre-processor for loop unrolling and array-to-register-promotion for e.g. ARM Mali |
-| [#207](https://github.com/CNugteren/CLBlast/issues/207) | Dec '17 | CNugteren | | Tuning of the TRSM/TRSV routines |
-| [#195](https://github.com/CNugteren/CLBlast/issues/195) | Jan '18 | CNugteren | | Extra GEMM API with pre-allocated temporary buffer |
+| [#207](https://github.com/CNugteren/CLBlast/issues/207) | Dec '17 | CNugteren | ✔ | Tuning of the TRSM/TRSV routines |
+| [#195](https://github.com/CNugteren/CLBlast/issues/195) | Jan '18 | CNugteren | ✔ | Extra GEMM API with pre-allocated temporary buffer |
+| [#233](https://github.com/CNugteren/CLBlast/issues/233) | Jan '18 | CNugteren | | Add CLBlast to common package managers |
| [#224](https://github.com/CNugteren/CLBlast/issues/224) | Jan-Feb '18 | CNugteren | | Implement Hadamard product (element-wise vector-vector product) |
| [#223](https://github.com/CNugteren/CLBlast/issues/223) | Feb '18 | CNugteren | | Python OpenCL interface |
| [#169](https://github.com/CNugteren/CLBlast/issues/169) | ?? | dividiti | | Problem-specific tuning parameter selection |
diff --git a/doc/clblast.md b/doc/clblast.md
index 88563bc1..5ee601f5 100644
--- a/doc/clblast.md
+++ b/doc/clblast.md
@@ -2208,7 +2208,8 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
- cl_command_queue* queue, cl_event* event)
+ cl_command_queue* queue, cl_event* event,
+ cl_mem temp_buffer = nullptr)
```
C API:
diff --git a/include/clblast.h b/include/clblast.h
index e073b211..a05b487f 100644
--- a/include/clblast.h
+++ b/include/clblast.h
@@ -97,6 +97,7 @@ enum class StatusCode {
kInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small
// Custom additional status codes for CLBlast
+ kInsufficientMemoryTemp = -2050, // Temporary buffer provided to GEMM routine is too small
kInvalidBatchCount = -2049, // The batch count needs to be positive
kInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel
kMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel
@@ -520,7 +521,8 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
- cl_command_queue* queue, cl_event* event = nullptr);
+ cl_command_queue* queue, cl_event* event = nullptr,
+ cl_mem temp_buffer = nullptr);
// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
template <typename T>
@@ -647,6 +649,17 @@ StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const T
// =================================================================================================
+// Retrieves the required size of the temporary buffer for the GEMM kernel (optional)
+template <typename T>
+StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, size_t& temp_buffer_size);
+
+// =================================================================================================
+
// CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on
// for the same device. This cache can be cleared to free up system memory or in case of debugging.
StatusCode PUBLIC_API ClearCache();
diff --git a/include/clblast_cuda.h b/include/clblast_cuda.h
index e28f68e5..e1237936 100644
--- a/include/clblast_cuda.h
+++ b/include/clblast_cuda.h
@@ -69,6 +69,7 @@ enum class StatusCode {
kInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small
// Custom additional status codes for CLBlast
+ kInsufficientMemoryTemp = -2050, // Temporary buffer provided to GEMM routine is too small
kInvalidBatchCount = -2049, // The batch count needs to be positive
kInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel
kMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel
@@ -492,7 +493,8 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
- const CUcontext context, const CUdevice device);
+ const CUcontext context, const CUdevice device,
+ CUdeviceptr temp_buffer = 0);
// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
template <typename T>
@@ -619,6 +621,17 @@ StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const T
// =================================================================================================
+// Retrieves the required size of the temporary buffer for the GEMM kernel (optional)
+template <typename T>
+StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ const CUdevice device, size_t& temp_buffer_size);
+
+// =================================================================================================
+
// CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on
// for the same device. This cache can be cleared to free up system memory or in case of debugging.
StatusCode PUBLIC_API ClearCache();
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 520e3fc8..5fbce2c4 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -46,8 +46,8 @@ FILES = [
"/include/clblast_cuda.h",
"/src/clblast_cuda.cpp",
]
-HEADER_LINES = [122, 21, 126, 24, 29, 41, 29, 65, 32, 94, 21]
-FOOTER_LINES = [25, 3, 27, 38, 6, 6, 6, 9, 2, 25, 3]
+HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 95, 21]
+FOOTER_LINES = [36, 56, 27, 38, 6, 6, 6, 9, 2, 36, 55]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 63
@@ -109,71 +109,71 @@ col = "height * width * channels"
im2col_constants = ["channels", "height", "width", "kernel_h", "kernel_w", "pad_h", "pad_w", "stride_h", "stride_w", "dilation_h", "dilation_w"]
ROUTINES = [
[ # Level 1: vector-vector
- Routine(False, True, False, "1", "rotg", T, [S,D], [], [], [], ["sa","sb","sc","ss"], ["1","1","1","1"], [], "", "Generate givens plane rotation", "", []),
- Routine(False, True, False, "1", "rotmg", T, [S,D], [], [], ["sy1"], ["sd1","sd2","sx1","sparam"], ["1","1","1","1","1"], [], "", "Generate modified givens plane rotation", "", []),
- Routine(False, True, False, "1", "rot", T, [S,D], ["n"], [], [], ["x","y"], [xn,yn], ["cos","sin"],"", "Apply givens plane rotation", "", []),
- Routine(False, True, False, "1", "rotm", T, [S,D], ["n"], [], [], ["x","y","sparam"], [xn,yn,"1"], [], "", "Apply modified givens plane rotation", "", []),
- Routine(True, True, False, "1", "swap", T, [S,D,C,Z,H], ["n"], [], [], ["x","y"], [xn,yn], [], "", "Swap two vectors", "Interchanges _n_ elements of vectors _x_ and _y_.", []),
- Routine(True, True, False, "1", "scal", T, [S,D,C,Z,H], ["n"], [], [], ["x"], [xn], ["alpha"], "", "Vector scaling", "Multiplies _n_ elements of vector _x_ by a scalar constant _alpha_.", []),
- Routine(True, True, False, "1", "copy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], [], "", "Vector copy", "Copies the contents of vector _x_ into vector _y_.", []),
- Routine(True, True, False, "1", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation _y = alpha * x + y_, in which _x_ and _y_ are vectors and _alpha_ is a scalar constant.", []),
- Routine(True, True, False, "1", "dot", T, [S,D,H], ["n"], [], ["x","y"], ["dot"], [xn,yn,"1"], [], "n", "Dot product of two vectors", "Multiplies _n_ elements of the vectors _x_ and _y_ element-wise and accumulates the results. The sum is stored in the _dot_ buffer.", []),
- Routine(True, True, False, "1", "dotu", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [xn,yn,"1"], [], "n", "Dot product of two complex vectors", "See the regular xDOT routine.", []),
- Routine(True, True, False, "1", "dotc", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [xn,yn,"1"], [], "n", "Dot product of two complex vectors, one conjugated", "See the regular xDOT routine.", []),
- Routine(True, True, False, "1", "nrm2", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["nrm2"], [xn,"1"], [], "2*n", "Euclidian norm of a vector", "Accumulates the square of _n_ elements in the _x_ vector and takes the square root. The resulting L2 norm is stored in the _nrm2_ buffer.", []),
- Routine(True, True, False, "1", "asum", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["asum"], [xn,"1"], [], "n", "Absolute sum of values in a vector", "Accumulates the absolute value of _n_ elements in the _x_ vector. The results are stored in the _asum_ buffer.", []),
- Routine(True, False, False, "1", "sum", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["sum"], [xn,"1"], [], "n", "Sum of values in a vector (non-BLAS function)", "Accumulates the values of _n_ elements in the _x_ vector. The results are stored in the _sum_ buffer. This routine is the non-absolute version of the xASUM BLAS routine.", []),
- Routine(True, True, False, "1", "amax", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imax"], [xn,"1"], [], "2*n", "Index of absolute maximum value in a vector", "Finds the index of the maximum of the absolute values in the _x_ vector. The resulting integer index is stored in the _imax_ buffer.", []),
- Routine(True, False, False, "1", "amin", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imin"], [xn,"1"], [], "2*n", "Index of absolute minimum value in a vector (non-BLAS function)", "Finds the index of the minimum of the absolute values in the _x_ vector. The resulting integer index is stored in the _imin_ buffer.", []),
- Routine(True, False, False, "1", "max", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imax"], [xn,"1"], [], "2*n", "Index of maximum value in a vector (non-BLAS function)", "Finds the index of the maximum of the values in the _x_ vector. The resulting integer index is stored in the _imax_ buffer. This routine is the non-absolute version of the IxAMAX BLAS routine.", []),
- Routine(True, False, False, "1", "min", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imin"], [xn,"1"], [], "2*n", "Index of minimum value in a vector (non-BLAS function)", "Finds the index of the minimum of the values in the _x_ vector. The resulting integer index is stored in the _imin_ buffer. This routine is the non-absolute minimum version of the IxAMAX BLAS routine.", []),
+ Routine(False, True, False, False, "1", "rotg", T, [S,D], [], [], [], ["sa","sb","sc","ss"], ["1","1","1","1"], [], "", "Generate givens plane rotation", "", []),
+ Routine(False, True, False, False, "1", "rotmg", T, [S,D], [], [], ["sy1"], ["sd1","sd2","sx1","sparam"], ["1","1","1","1","1"], [], "", "Generate modified givens plane rotation", "", []),
+ Routine(False, True, False, False, "1", "rot", T, [S,D], ["n"], [], [], ["x","y"], [xn,yn], ["cos","sin"],"", "Apply givens plane rotation", "", []),
+ Routine(False, True, False, False, "1", "rotm", T, [S,D], ["n"], [], [], ["x","y","sparam"], [xn,yn,"1"], [], "", "Apply modified givens plane rotation", "", []),
+ Routine(True, True, False, False, "1", "swap", T, [S,D,C,Z,H], ["n"], [], [], ["x","y"], [xn,yn], [], "", "Swap two vectors", "Interchanges _n_ elements of vectors _x_ and _y_.", []),
+ Routine(True, True, False, False, "1", "scal", T, [S,D,C,Z,H], ["n"], [], [], ["x"], [xn], ["alpha"], "", "Vector scaling", "Multiplies _n_ elements of vector _x_ by a scalar constant _alpha_.", []),
+ Routine(True, True, False, False, "1", "copy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], [], "", "Vector copy", "Copies the contents of vector _x_ into vector _y_.", []),
+ Routine(True, True, False, False, "1", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation _y = alpha * x + y_, in which _x_ and _y_ are vectors and _alpha_ is a scalar constant.", []),
+ Routine(True, True, False, False, "1", "dot", T, [S,D,H], ["n"], [], ["x","y"], ["dot"], [xn,yn,"1"], [], "n", "Dot product of two vectors", "Multiplies _n_ elements of the vectors _x_ and _y_ element-wise and accumulates the results. The sum is stored in the _dot_ buffer.", []),
+ Routine(True, True, False, False, "1", "dotu", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [xn,yn,"1"], [], "n", "Dot product of two complex vectors", "See the regular xDOT routine.", []),
+ Routine(True, True, False, False, "1", "dotc", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [xn,yn,"1"], [], "n", "Dot product of two complex vectors, one conjugated", "See the regular xDOT routine.", []),
+ Routine(True, True, False, False, "1", "nrm2", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["nrm2"], [xn,"1"], [], "2*n", "Euclidian norm of a vector", "Accumulates the square of _n_ elements in the _x_ vector and takes the square root. The resulting L2 norm is stored in the _nrm2_ buffer.", []),
+ Routine(True, True, False, False, "1", "asum", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["asum"], [xn,"1"], [], "n", "Absolute sum of values in a vector", "Accumulates the absolute value of _n_ elements in the _x_ vector. The results are stored in the _asum_ buffer.", []),
+ Routine(True, False, False, False, "1", "sum", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["sum"], [xn,"1"], [], "n", "Sum of values in a vector (non-BLAS function)", "Accumulates the values of _n_ elements in the _x_ vector. The results are stored in the _sum_ buffer. This routine is the non-absolute version of the xASUM BLAS routine.", []),
+ Routine(True, True, False, False, "1", "amax", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imax"], [xn,"1"], [], "2*n", "Index of absolute maximum value in a vector", "Finds the index of the maximum of the absolute values in the _x_ vector. The resulting integer index is stored in the _imax_ buffer.", []),
+ Routine(True, False, False, False, "1", "amin", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imin"], [xn,"1"], [], "2*n", "Index of absolute minimum value in a vector (non-BLAS function)", "Finds the index of the minimum of the absolute values in the _x_ vector. The resulting integer index is stored in the _imin_ buffer.", []),
+ Routine(True, False, False, False, "1", "max", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imax"], [xn,"1"], [], "2*n", "Index of maximum value in a vector (non-BLAS function)", "Finds the index of the maximum of the values in the _x_ vector. The resulting integer index is stored in the _imax_ buffer. This routine is the non-absolute version of the IxAMAX BLAS routine.", []),
+ Routine(True, False, False, False, "1", "min", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imin"], [xn,"1"], [], "2*n", "Index of minimum value in a vector (non-BLAS function)", "Finds the index of the minimum of the values in the _x_ vector. The resulting integer index is stored in the _imin_ buffer. This routine is the non-absolute minimum version of the IxAMAX BLAS routine.", []),
],
[ # Level 2: matrix-vector
- Routine(True, True, False, "2a", "gemv", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a","x"], ["y"], [amn,xmn,ynm], ["alpha","beta"], "", "General matrix-vector multiplication", "Performs the operation _y = alpha * A * x + beta * y_, in which _x_ is an input vector, _y_ is an input and output vector, _A_ is an input matrix, and _alpha_ and _beta_ are scalars. The matrix _A_ can optionally be transposed before performing the operation.", [ald_m]),
- Routine(True, True, False, "2a", "gbmv", T, [S,D,C,Z,H], ["m","n","kl","ku"], ["layout","a_transpose"], ["a","x"], ["y"], [amn,xmn,ynm], ["alpha","beta"], "", "General banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is banded instead.", [ald_kl_ku_one]),
- Routine(True, True, False, "2a", "hemv", T, [C,Z], ["n"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Hermitian matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian matrix instead.", [ald_n]),
- Routine(True, True, False, "2a", "hbmv", T, [C,Z], ["n","k"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Hermitian banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian banded matrix instead.", [ald_k_one]),
- Routine(True, True, False, "2a", "hpmv", T, [C,Z], ["n"], ["layout","triangle"], ["ap","x"], ["y"], [apn,xn,yn], ["alpha","beta"], "", "Hermitian packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
- Routine(True, True, False, "2a", "symv", T, [S,D,H], ["n"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Symmetric matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is symmetric instead.", [ald_n]),
- Routine(True, True, False, "2a", "sbmv", T, [S,D,H], ["n","k"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Symmetric banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is symmetric and banded instead.", [ald_k_one]),
- Routine(True, True, False, "2a", "spmv", T, [S,D,H], ["n"], ["layout","triangle"], ["ap","x"], ["y"], [apn,xn,yn], ["alpha","beta"], "", "Symmetric packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
- Routine(True, True, False, "2a", "trmv", T, [S,D,C,Z,H], ["n"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "n", "Triangular matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is triangular instead.", [ald_n]),
- Routine(True, True, False, "2a", "tbmv", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "n", "Triangular banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is triangular and banded instead.", [ald_k_one]),
- Routine(True, True, False, "2a", "tpmv", T, [S,D,C,Z,H], ["n"], ["layout","triangle","a_transpose","diagonal"], ["ap"], ["x"], [apn,xn], [], "n", "Triangular packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is a triangular packed matrix instead and repreented as _AP_.", []),
- Routine(True, True, False, "2a", "trsv", T, [S,D,C,Z], ["n"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "", "Solves a triangular system of equations", "", []),
- Routine(False, True, False, "2a", "tbsv", T, [S,D,C,Z], ["n","k"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "", "Solves a banded triangular system of equations", "", [ald_k_one]),
- Routine(False, True, False, "2a", "tpsv", T, [S,D,C,Z], ["n"], ["layout","triangle","a_transpose","diagonal"], ["ap"], ["x"], [apn,xn], [], "", "Solves a packed triangular system of equations", "", []),
+ Routine(True, True, False, False, "2a", "gemv", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a","x"], ["y"], [amn,xmn,ynm], ["alpha","beta"], "", "General matrix-vector multiplication", "Performs the operation _y = alpha * A * x + beta * y_, in which _x_ is an input vector, _y_ is an input and output vector, _A_ is an input matrix, and _alpha_ and _beta_ are scalars. The matrix _A_ can optionally be transposed before performing the operation.", [ald_m]),
+ Routine(True, True, False, False, "2a", "gbmv", T, [S,D,C,Z,H], ["m","n","kl","ku"], ["layout","a_transpose"], ["a","x"], ["y"], [amn,xmn,ynm], ["alpha","beta"], "", "General banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is banded instead.", [ald_kl_ku_one]),
+ Routine(True, True, False, False, "2a", "hemv", T, [C,Z], ["n"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Hermitian matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian matrix instead.", [ald_n]),
+ Routine(True, True, False, False, "2a", "hbmv", T, [C,Z], ["n","k"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Hermitian banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian banded matrix instead.", [ald_k_one]),
+ Routine(True, True, False, False, "2a", "hpmv", T, [C,Z], ["n"], ["layout","triangle"], ["ap","x"], ["y"], [apn,xn,yn], ["alpha","beta"], "", "Hermitian packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
+ Routine(True, True, False, False, "2a", "symv", T, [S,D,H], ["n"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Symmetric matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is symmetric instead.", [ald_n]),
+ Routine(True, True, False, False, "2a", "sbmv", T, [S,D,H], ["n","k"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Symmetric banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is symmetric and banded instead.", [ald_k_one]),
+ Routine(True, True, False, False, "2a", "spmv", T, [S,D,H], ["n"], ["layout","triangle"], ["ap","x"], ["y"], [apn,xn,yn], ["alpha","beta"], "", "Symmetric packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
+ Routine(True, True, False, False, "2a", "trmv", T, [S,D,C,Z,H], ["n"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "n", "Triangular matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is triangular instead.", [ald_n]),
+ Routine(True, True, False, False, "2a", "tbmv", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "n", "Triangular banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is triangular and banded instead.", [ald_k_one]),
+ Routine(True, True, False, False, "2a", "tpmv", T, [S,D,C,Z,H], ["n"], ["layout","triangle","a_transpose","diagonal"], ["ap"], ["x"], [apn,xn], [], "n", "Triangular packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is a triangular packed matrix instead and repreented as _AP_.", []),
+ Routine(True, True, False, False, "2a", "trsv", T, [S,D,C,Z], ["n"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "", "Solves a triangular system of equations", "", []),
+ Routine(False, True, False, False, "2a", "tbsv", T, [S,D,C,Z], ["n","k"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "", "Solves a banded triangular system of equations", "", [ald_k_one]),
+ Routine(False, True, False, False, "2a", "tpsv", T, [S,D,C,Z], ["n"], ["layout","triangle","a_transpose","diagonal"], ["ap"], ["x"], [apn,xn], [], "", "Solves a packed triangular system of equations", "", []),
# Level 2: matrix update
- Routine(True, True, False, "2b", "ger", T, [S,D,H], ["m","n"], ["layout"], ["x","y"], ["a"], [xm,yn,amn], ["alpha"], "", "General rank-1 matrix update", "Performs the operation _A = alpha * x * y^T + A_, in which _x_ is an input vector, _y^T_ is the transpose of the input vector _y_, _A_ is the matrix to be updated, and _alpha_ is a scalar value.", [ald_m]),
- Routine(True, True, False, "2b", "geru", T, [C,Z], ["m","n"], ["layout"], ["x","y"], ["a"], [xm,yn,amn], ["alpha"], "", "General rank-1 complex matrix update", "Same operation as xGER, but with complex data-types.", [ald_m]),
- Routine(True, True, False, "2b", "gerc", T, [C,Z], ["m","n"], ["layout"], ["x","y"], ["a"], [xm,yn,amn], ["alpha"], "", "General rank-1 complex conjugated matrix update", "Same operation as xGERU, but the update is done based on the complex conjugate of the input vectors.", [ald_m]),
- Routine(True, True, False, "2b", "her", Tc, [Css,Zdd], ["n"], ["layout","triangle"], ["x"], ["a"], [xn,an], ["alpha"], "", "Hermitian rank-1 matrix update", "Performs the operation _A = alpha * x * x^T + A_, in which x is an input vector, x^T is the transpose of this vector, _A_ is the triangular Hermetian matrix to be updated, and alpha is a scalar value.", [ald_n]),
- Routine(True, True, False, "2b", "hpr", Tc, [Css,Zdd], ["n"], ["layout","triangle"], ["x"], ["ap"], [xn,apn], ["alpha"], "", "Hermitian packed rank-1 matrix update", "Same operation as xHER, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
- Routine(True, True, False, "2b", "her2", T, [C,Z], ["n"], ["layout","triangle"], ["x","y"], ["a"], [xn,yn,an], ["alpha"], "", "Hermitian rank-2 matrix update", "Performs the operation _A = alpha * x * y^T + conj(alpha) * y * x^T + A_, in which _x_ is an input vector and _x^T_ its transpose, _y_ is an input vector and _y^T_ its transpose, _A_ is the triangular Hermetian matrix to be updated, _alpha_ is a scalar value and _conj(alpha)_ its complex conjugate.", [ald_n]),
- Routine(True, True, False, "2b", "hpr2", T, [C,Z], ["n"], ["layout","triangle"], ["x","y"], ["ap"], [xn,yn,apn], ["alpha"], "", "Hermitian packed rank-2 matrix update", "Same operation as xHER2, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
- Routine(True, True, False, "2b", "syr", T, [S,D,H], ["n"], ["layout","triangle"], ["x"], ["a"], [xn,an], ["alpha"], "", "Symmetric rank-1 matrix update", "Same operation as xHER, but matrix A is a symmetric matrix instead.", [ald_n]),
- Routine(True, True, False, "2b", "spr", T, [S,D,H], ["n"], ["layout","triangle"], ["x"], ["ap"], [xn,apn], ["alpha"], "", "Symmetric packed rank-1 matrix update", "Same operation as xSPR, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
- Routine(True, True, False, "2b", "syr2", T, [S,D,H], ["n"], ["layout","triangle"], ["x","y"], ["a"], [xn,yn,an], ["alpha"], "", "Symmetric rank-2 matrix update", "Same operation as xHER2, but matrix _A_ is a symmetric matrix instead.", [ald_n]),
- Routine(True, True, False, "2b", "spr2", T, [S,D,H], ["n"], ["layout","triangle"], ["x","y"], ["ap"], [xn,yn,apn], ["alpha"], "", "Symmetric packed rank-2 matrix update", "Same operation as xSPR2, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
+ Routine(True, True, False, False, "2b", "ger", T, [S,D,H], ["m","n"], ["layout"], ["x","y"], ["a"], [xm,yn,amn], ["alpha"], "", "General rank-1 matrix update", "Performs the operation _A = alpha * x * y^T + A_, in which _x_ is an input vector, _y^T_ is the transpose of the input vector _y_, _A_ is the matrix to be updated, and _alpha_ is a scalar value.", [ald_m]),
+ Routine(True, True, False, False, "2b", "geru", T, [C,Z], ["m","n"], ["layout"], ["x","y"], ["a"], [xm,yn,amn], ["alpha"], "", "General rank-1 complex matrix update", "Same operation as xGER, but with complex data-types.", [ald_m]),
+ Routine(True, True, False, False, "2b", "gerc", T, [C,Z], ["m","n"], ["layout"], ["x","y"], ["a"], [xm,yn,amn], ["alpha"], "", "General rank-1 complex conjugated matrix update", "Same operation as xGERU, but the update is done based on the complex conjugate of the input vectors.", [ald_m]),
+ Routine(True, True, False, False, "2b", "her", Tc, [Css,Zdd], ["n"], ["layout","triangle"], ["x"], ["a"], [xn,an], ["alpha"], "", "Hermitian rank-1 matrix update", "Performs the operation _A = alpha * x * x^T + A_, in which x is an input vector, x^T is the transpose of this vector, _A_ is the triangular Hermetian matrix to be updated, and alpha is a scalar value.", [ald_n]),
+ Routine(True, True, False, False, "2b", "hpr", Tc, [Css,Zdd], ["n"], ["layout","triangle"], ["x"], ["ap"], [xn,apn], ["alpha"], "", "Hermitian packed rank-1 matrix update", "Same operation as xHER, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
+ Routine(True, True, False, False, "2b", "her2", T, [C,Z], ["n"], ["layout","triangle"], ["x","y"], ["a"], [xn,yn,an], ["alpha"], "", "Hermitian rank-2 matrix update", "Performs the operation _A = alpha * x * y^T + conj(alpha) * y * x^T + A_, in which _x_ is an input vector and _x^T_ its transpose, _y_ is an input vector and _y^T_ its transpose, _A_ is the triangular Hermetian matrix to be updated, _alpha_ is a scalar value and _conj(alpha)_ its complex conjugate.", [ald_n]),
+ Routine(True, True, False, False, "2b", "hpr2", T, [C,Z], ["n"], ["layout","triangle"], ["x","y"], ["ap"], [xn,yn,apn], ["alpha"], "", "Hermitian packed rank-2 matrix update", "Same operation as xHER2, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
+ Routine(True, True, False, False, "2b", "syr", T, [S,D,H], ["n"], ["layout","triangle"], ["x"], ["a"], [xn,an], ["alpha"], "", "Symmetric rank-1 matrix update", "Same operation as xHER, but matrix A is a symmetric matrix instead.", [ald_n]),
+ Routine(True, True, False, False, "2b", "spr", T, [S,D,H], ["n"], ["layout","triangle"], ["x"], ["ap"], [xn,apn], ["alpha"], "", "Symmetric packed rank-1 matrix update", "Same operation as xSPR, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
+ Routine(True, True, False, False, "2b", "syr2", T, [S,D,H], ["n"], ["layout","triangle"], ["x","y"], ["a"], [xn,yn,an], ["alpha"], "", "Symmetric rank-2 matrix update", "Same operation as xHER2, but matrix _A_ is a symmetric matrix instead.", [ald_n]),
+ Routine(True, True, False, False, "2b", "spr2", T, [S,D,H], ["n"], ["layout","triangle"], ["x","y"], ["ap"], [xn,yn,apn], ["alpha"], "", "Symmetric packed rank-2 matrix update", "Same operation as xSPR2, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
],
[ # Level 3: matrix-matrix
- Routine(True, True, False, "3", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "General matrix-matrix multiplication", "Performs the matrix product _C = alpha * A * B + beta * C_, in which _A_ (_m_ by _k_) and _B_ (_k_ by _n_) are two general rectangular input matrices, _C_ (_m_ by _n_) is the matrix to be updated, and _alpha_ and _beta_ are scalar values. The matrices _A_ and/or _B_ can optionally be transposed before performing the operation.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
- Routine(True, True, False, "3", "symm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], [ammn,bmnn,cmn], ["alpha","beta"], "", "Symmetric matrix-matrix multiplication", "Same operation as xGEMM, but _A_ is symmetric instead. In case of `side == kLeft`, _A_ is a symmetric _m_ by _m_ matrix and _C = alpha * A * B + beta * C_ is performed. Otherwise, in case of `side == kRight`, _A_ is a symmtric _n_ by _n_ matrix and _C = alpha * B * A + beta * C_ is performed.", [ald_side_m_n, bld_m, cld_m]),
- Routine(True, True, False, "3", "hemm", T, [C,Z], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], [ammn,bmnn,cmn], ["alpha","beta"], "", "Hermitian matrix-matrix multiplication", "Same operation as xSYMM, but _A_ is an Hermitian matrix instead.", [ald_side_m_n, bld_m, cld_m]),
- Routine(True, True, False, "3", "syrk", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], [ank,cn], ["alpha","beta"], "", "Rank-K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * A^T + beta * C_ or _C = alpha * A^T * A + beta * C_, in which _A_ is a general matrix and _A^T_ is its transpose, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, cld_m]),
- Routine(True, True, False, "3", "herk", Tc, [Css,Zdd], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], [ank,cn], ["alpha","beta"], "", "Rank-K update of a hermitian matrix", "Same operation as xSYRK, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, cld_m]),
- Routine(True, True, False, "3", "syr2k", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], [ankab,bnkab,cn],["alpha","beta"], "", "Rank-2K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * B^T + alpha * B * A^T + beta * C_ or _C = alpha * A^T * B + alpha * B^T * A + beta * C_, in which _A_ and _B_ are general matrices and _A^T_ and _B^T_ are their transposed versions, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
- Routine(True, True, False, "3", "her2k", TU, [Ccs,Zzd], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], [ankab,bnkab,cn],["alpha","beta"], "", "Rank-2K update of a hermitian matrix", "Same operation as xSYR2K, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
- Routine(True, True, False, "3", "trmm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Triangular matrix-matrix multiplication", "Performs the matrix product _B = alpha * A * B_ or _B = alpha * B * A_, in which _A_ is a unit or non-unit triangular matrix, _B_ (_m_ by _n_) is the general matrix to be updated, and _alpha_ is a scalar value.", [ald_side_m_n, bld_m]),
- Routine(True, True, False, "3", "trsm", T, [S,D,C,Z], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Solves a triangular system of equations", "Solves the equation _A * X = alpha * B_ for the unknown _m_ by _n_ matrix X, in which _A_ is an _n_ by _n_ unit or non-unit triangular matrix and B is an _m_ by _n_ matrix. The matrix _B_ is overwritten by the solution _X_.", []),
+ Routine(True, True, False, True, "3", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "General matrix-matrix multiplication", "Performs the matrix product _C = alpha * A * B + beta * C_, in which _A_ (_m_ by _k_) and _B_ (_k_ by _n_) are two general rectangular input matrices, _C_ (_m_ by _n_) is the matrix to be updated, and _alpha_ and _beta_ are scalar values. The matrices _A_ and/or _B_ can optionally be transposed before performing the operation.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
+ Routine(True, True, False, False, "3", "symm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], [ammn,bmnn,cmn], ["alpha","beta"], "", "Symmetric matrix-matrix multiplication", "Same operation as xGEMM, but _A_ is symmetric instead. In case of `side == kLeft`, _A_ is a symmetric _m_ by _m_ matrix and _C = alpha * A * B + beta * C_ is performed. Otherwise, in case of `side == kRight`, _A_ is a symmtric _n_ by _n_ matrix and _C = alpha * B * A + beta * C_ is performed.", [ald_side_m_n, bld_m, cld_m]),
+ Routine(True, True, False, False, "3", "hemm", T, [C,Z], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], [ammn,bmnn,cmn], ["alpha","beta"], "", "Hermitian matrix-matrix multiplication", "Same operation as xSYMM, but _A_ is an Hermitian matrix instead.", [ald_side_m_n, bld_m, cld_m]),
+ Routine(True, True, False, False, "3", "syrk", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], [ank,cn], ["alpha","beta"], "", "Rank-K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * A^T + beta * C_ or _C = alpha * A^T * A + beta * C_, in which _A_ is a general matrix and _A^T_ is its transpose, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, cld_m]),
+ Routine(True, True, False, False, "3", "herk", Tc, [Css,Zdd], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], [ank,cn], ["alpha","beta"], "", "Rank-K update of a hermitian matrix", "Same operation as xSYRK, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, cld_m]),
+ Routine(True, True, False, False, "3", "syr2k", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], [ankab,bnkab,cn],["alpha","beta"], "", "Rank-2K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * B^T + alpha * B * A^T + beta * C_ or _C = alpha * A^T * B + alpha * B^T * A + beta * C_, in which _A_ and _B_ are general matrices and _A^T_ and _B^T_ are their transposed versions, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
+ Routine(True, True, False, False, "3", "her2k", TU, [Ccs,Zzd], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], [ankab,bnkab,cn],["alpha","beta"], "", "Rank-2K update of a hermitian matrix", "Same operation as xSYR2K, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
+ Routine(True, True, False, False, "3", "trmm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Triangular matrix-matrix multiplication", "Performs the matrix product _B = alpha * A * B_ or _B = alpha * B * A_, in which _A_ is a unit or non-unit triangular matrix, _B_ (_m_ by _n_) is the general matrix to be updated, and _alpha_ is a scalar value.", [ald_side_m_n, bld_m]),
+ Routine(True, True, False, False, "3", "trsm", T, [S,D,C,Z], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Solves a triangular system of equations", "Solves the equation _A * X = alpha * B_ for the unknown _m_ by _n_ matrix X, in which _A_ is an _n_ by _n_ unit or non-unit triangular matrix and B is an _m_ by _n_ matrix. The matrix _B_ is overwritten by the solution _X_.", []),
],
[ # Level X: extra routines (not part of BLAS)
# Special routines:
- Routine(True, True, False, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
- Routine(True, True, False, "x", "im2col", T, [S,D,C,Z,H], im2col_constants, [], ["im"], ["col"], [im,col], [""], "", "Im2col function (non-BLAS function)", "Performs the im2col algorithm, in which _im_ is the input matrix and _col_ is the output matrix.", []),
+ Routine(True, True, False, False, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
+ Routine(True, True, False, False, "x", "im2col", T, [S,D,C,Z,H], im2col_constants, [], ["im"], ["col"], [im,col], [""], "", "Im2col function (non-BLAS function)", "Performs the im2col algorithm, in which _im_ is the input matrix and _col_ is the output matrix.", []),
# Batched routines:
- Routine(True, True, True, "x", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Batched version of AXPY", "As AXPY, but multiple operations are batched together for better performance.", []),
- Routine(True, True, True, "x", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "Batched version of GEMM", "As GEMM, but multiple operations are batched together for better performance.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
+ Routine(True, True, True, False, "x", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Batched version of AXPY", "As AXPY, but multiple operations are batched together for better performance.", []),
+ Routine(True, True, True, False, "x", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "Batched version of GEMM", "As GEMM, but multiple operations are batched together for better performance.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
]]
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py
index 2d18655f..656253d7 100644
--- a/scripts/generator/generator/cpp.py
+++ b/scripts/generator/generator/cpp.py
@@ -48,7 +48,7 @@ def clblast_cc(routine, cuda=False):
indent1 = " " * (15 + routine.length())
result = NL + "// " + routine.description + ": " + routine.short_names() + NL
if routine.implemented:
- result += routine.routine_header_cpp(12, "", cuda) + " {" + NL
+ result += routine.routine_header_cpp(12, "", cuda, implementation=True) + " {" + NL
result += " try {" + NL
if cuda:
result += " const auto context_cpp = Context(context);" + NL
@@ -60,8 +60,13 @@ def clblast_cc(routine, cuda=False):
result += " auto routine = X" + routine.plain_name() + "<" + routine.template.template + ">(queue_cpp, " + event + ");" + NL
if routine.batched:
result += " " + (NL + " ").join(routine.batched_transform_to_cpp()) + NL
+ if routine.temp_buffer:
+ result += " const auto temp_buffer_provided = temp_buffer != nullptr;\n"
+ result += " auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(nullptr);\n"
result += " routine.Do" + routine.capitalized_name() + "("
result += ("," + NL + indent1).join([a for a in routine.arguments_clcudaapi()])
+ if routine.temp_buffer:
+ result += ",\n" + indent1 + "temp_buffer_cpp, temp_buffer_provided"
result += ");" + NL
result += " return StatusCode::kSuccess;" + NL
result += " } catch (...) { return DispatchException(); }" + NL
@@ -79,8 +84,12 @@ def clblast_cc(routine, cuda=False):
result += "," + NL + indent2
if cuda:
result += "const CUcontext, const CUdevice"
+ if routine.temp_buffer:
+ result += ", CUdeviceptr"
else:
result += "cl_command_queue*, cl_event*"
+ if routine.temp_buffer:
+ result += ", cl_mem"
result += ");" + NL
return result
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index b6b55821..22be02b0 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -12,12 +12,13 @@ import generator.convert as convert
class Routine:
"""Class holding routine-specific information (e.g. name, which arguments, which precisions)"""
- def __init__(self, implemented, has_tests, batched, level, name, template, flavours, sizes, options,
+ def __init__(self, implemented, has_tests, batched, temp_buffer, level, name, template, flavours, sizes, options,
inputs, outputs, buffer_sizes, scalars, scratch,
description, details, requirements):
self.implemented = implemented
self.has_tests = has_tests
self.batched = batched
+ self.temp_buffer = temp_buffer
self.level = level
self.name = name
self.template = template
@@ -802,12 +803,14 @@ class Routine:
"""Retrieves a list of routine requirements for documentation"""
return self.requirements
- def routine_header_cpp(self, spaces, default_event, cuda=False):
+ def routine_header_cpp(self, spaces, default_event, cuda=False, implementation=False):
"""Retrieves the C++ templated definition for a routine"""
indent = " " * (spaces + self.length())
arguments = self.arguments_def(self.template)
+ mem_type = "cl_mem"
if cuda:
- arguments = [a.replace("cl_mem", "CUdeviceptr") for a in arguments]
+ arguments = [a.replace(mem_type, "CUdeviceptr") for a in arguments]
+ mem_type = "CUdeviceptr"
result = "template <" + self.template.name + ">\n"
result += "StatusCode " + self.capitalized_name() + "("
result += (",\n" + indent).join([a for a in arguments])
@@ -816,6 +819,10 @@ class Routine:
result += "const CUcontext context, const CUdevice device"
else:
result += "cl_command_queue* queue, cl_event* event" + default_event
+ if self.temp_buffer:
+ result += ",\n" + indent + mem_type + " temp_buffer"
+ if not implementation:
+ result += " = nullptr"
result += ")"
return result
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 7d2c2cef..f5e2f1be 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -1651,17 +1651,21 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
- cl_command_queue* queue, cl_event* event) {
+ cl_command_queue* queue, cl_event* event,
+ cl_mem temp_buffer) {
try {
auto queue_cpp = Queue(*queue);
auto routine = Xgemm<T>(queue_cpp, event);
+ const auto temp_buffer_provided = temp_buffer != nullptr;
+ auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(nullptr);
routine.DoGemm(layout, a_transpose, b_transpose,
m, n, k,
alpha,
Buffer<T>(a_buffer), a_offset, a_ld,
Buffer<T>(b_buffer), b_offset, b_ld,
beta,
- Buffer<T>(c_buffer), c_offset, c_ld);
+ Buffer<T>(c_buffer), c_offset, c_ld,
+ temp_buffer_cpp, temp_buffer_provided);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
@@ -1672,7 +1676,7 @@ template StatusCode PUBLIC_API Gemm<float>(const Layout, const Transpose, const
const cl_mem, const size_t, const size_t,
const float,
cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*);
+ cl_command_queue*, cl_event*, cl_mem);
template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double,
@@ -1680,7 +1684,7 @@ template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const
const cl_mem, const size_t, const size_t,
const double,
cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*);
+ cl_command_queue*, cl_event*, cl_mem);
template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const float2,
@@ -1688,7 +1692,7 @@ template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const
const cl_mem, const size_t, const size_t,
const float2,
cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*);
+ cl_command_queue*, cl_event*, cl_mem);
template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double2,
@@ -1696,7 +1700,7 @@ template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, cons
const cl_mem, const size_t, const size_t,
const double2,
cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*);
+ cl_command_queue*, cl_event*, cl_mem);
template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const half,
@@ -1704,7 +1708,7 @@ template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const T
const cl_mem, const size_t, const size_t,
const half,
cl_mem, const size_t, const size_t,
- cl_command_queue*, cl_event*);
+ cl_command_queue*, cl_event*, cl_mem);
// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
template <typename T>
@@ -2333,4 +2337,57 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose,
cl_command_queue*, cl_event*);
// =================================================================================================
+
+// Retrieves the required size of the temporary buffer for the GEMM kernel (optional)
+template <typename T>
+StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, size_t& temp_buffer_size) {
+ try {
+
+ // Retrieves the tuning database
+ const auto queue_cpp = Queue(*queue);
+ const auto device = queue_cpp.GetDevice();
+ const auto kernel_names = std::vector<std::string>{"Xgemm", "GemmRoutine"};
+ Databases db(kernel_names);
+ Routine::InitDatabase(device, kernel_names, PrecisionValue<T>(), {}, db);
+
+ // Computes the buffer size
+ if (Xgemm<T>::UseDirectKernel(m, n, k, db["XGEMM_MIN_INDIRECT_SIZE"])) {
+ temp_buffer_size = 0;
+ }
+ else {
+ temp_buffer_size = Xgemm<T>::GetTempSize(layout, a_transpose, b_transpose, m, n, k,
+ a_offset, a_ld, b_offset, b_ld, c_offset, c_ld,
+ db["MWG"], db["NWG"], db["KWG"]);
+ }
+ temp_buffer_size *= sizeof(T); // translate from num-elements to bytes
+ return StatusCode::kSuccess;
+ } catch (...) { return DispatchException(); }
+}
+template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, cl_command_queue*, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, cl_command_queue*, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, cl_command_queue*, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, cl_command_queue*, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, cl_command_queue*, size_t&);
+
+// =================================================================================================
} // namespace clblast
diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp
index 0e3d949d..348ff3f5 100644
--- a/src/clblast_cuda.cpp
+++ b/src/clblast_cuda.cpp
@@ -1725,19 +1725,23 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
- const CUcontext context, const CUdevice device) {
+ const CUcontext context, const CUdevice device,
+ CUdeviceptr temp_buffer) {
try {
const auto context_cpp = Context(context);
const auto device_cpp = Device(device);
auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = Xgemm<T>(queue_cpp, nullptr);
+ const auto temp_buffer_provided = temp_buffer != 0;
+ auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(0);
routine.DoGemm(layout, a_transpose, b_transpose,
m, n, k,
alpha,
Buffer<T>(a_buffer), a_offset, a_ld,
Buffer<T>(b_buffer), b_offset, b_ld,
beta,
- Buffer<T>(c_buffer), c_offset, c_ld);
+ Buffer<T>(c_buffer), c_offset, c_ld,
+ temp_buffer_cpp, temp_buffer_provided);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
@@ -1748,7 +1752,7 @@ template StatusCode PUBLIC_API Gemm<float>(const Layout, const Transpose, const
const CUdeviceptr, const size_t, const size_t,
const float,
CUdeviceptr, const size_t, const size_t,
- const CUcontext, const CUdevice);
+ const CUcontext, const CUdevice, CUdeviceptr);
template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double,
@@ -1756,7 +1760,7 @@ template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const
const CUdeviceptr, const size_t, const size_t,
const double,
CUdeviceptr, const size_t, const size_t,
- const CUcontext, const CUdevice);
+ const CUcontext, const CUdevice, CUdeviceptr);
template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const float2,
@@ -1764,7 +1768,7 @@ template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const
const CUdeviceptr, const size_t, const size_t,
const float2,
CUdeviceptr, const size_t, const size_t,
- const CUcontext, const CUdevice);
+ const CUcontext, const CUdevice, CUdeviceptr);
template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double2,
@@ -1772,7 +1776,7 @@ template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, cons
const CUdeviceptr, const size_t, const size_t,
const double2,
CUdeviceptr, const size_t, const size_t,
- const CUcontext, const CUdevice);
+ const CUcontext, const CUdevice, CUdeviceptr);
template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const half,
@@ -1780,7 +1784,7 @@ template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const T
const CUdeviceptr, const size_t, const size_t,
const half,
CUdeviceptr, const size_t, const size_t,
- const CUcontext, const CUdevice);
+ const CUcontext, const CUdevice, CUdeviceptr);
// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
template <typename T>
@@ -2433,4 +2437,56 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose,
const CUcontext, const CUdevice);
// =================================================================================================
+
+// Retrieves the required size of the temporary buffer for the GEMM kernel (optional)
+template <typename T>
+StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ const CUdevice device, size_t& temp_buffer_size) {
+ try {
+
+ // Retrieves the tuning database
+ const auto device_cpp = Device(device);
+ const auto kernel_names = std::vector<std::string>{"Xgemm", "GemmRoutine"};
+ Databases db(kernel_names);
+ Routine::InitDatabase(device_cpp, kernel_names, PrecisionValue<T>(), {}, db);
+
+ // Computes the buffer size
+ if (Xgemm<T>::UseDirectKernel(m, n, k, db["XGEMM_MIN_INDIRECT_SIZE"])) {
+ temp_buffer_size = 0;
+ }
+ else {
+ temp_buffer_size = Xgemm<T>::GetTempSize(layout, a_transpose, b_transpose, m, n, k,
+ a_offset, a_ld, b_offset, b_ld, c_offset, c_ld,
+ db["MWG"], db["NWG"], db["KWG"]);
+ }
+ temp_buffer_size *= sizeof(T); // translate from num-elements to bytes
+ return StatusCode::kSuccess;
+ } catch (...) { return DispatchException(); }
+}
+template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, const CUdevice, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, const CUdevice, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, const CUdevice, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, const CUdevice, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, const CUdevice, size_t&);
+
+// =================================================================================================
} // namespace clblast
diff --git a/src/clpp11.hpp b/src/clpp11.hpp
index 6ebf1322..2119f26b 100644
--- a/src/clpp11.hpp
+++ b/src/clpp11.hpp
@@ -614,10 +614,11 @@ class Buffer {
}
// Regular constructor with memory management. If this class does not own the buffer object, then
- // the memory will not be freed automatically afterwards.
+ // the memory will not be freed automatically afterwards. If the size is set to 0, this will
+ // become a stub containing a nullptr
explicit Buffer(const Context &context, const BufferAccess access, const size_t size):
- buffer_(new cl_mem, [access](cl_mem* m) {
- if (access != BufferAccess::kNotOwned) { CheckError(clReleaseMemObject(*m)); }
+ buffer_(new cl_mem, [access, size](cl_mem* m) {
+ if (access != BufferAccess::kNotOwned && size > 0) { CheckError(clReleaseMemObject(*m)); }
delete m;
}),
access_(access) {
@@ -625,7 +626,7 @@ class Buffer {
if (access_ == BufferAccess::kReadOnly) { flags = CL_MEM_READ_ONLY; }
if (access_ == BufferAccess::kWriteOnly) { flags = CL_MEM_WRITE_ONLY; }
auto status = CL_SUCCESS;
- *buffer_ = clCreateBuffer(context(), flags, size*sizeof(T), nullptr, &status);
+ *buffer_ = (size > 0) ? clCreateBuffer(context(), flags, size*sizeof(T), nullptr, &status) : nullptr;
CLCudaAPIError::Check(status, "clCreateBuffer");
}
diff --git a/src/cupp11.hpp b/src/cupp11.hpp
index eb177ca2..509ae3e8 100644
--- a/src/cupp11.hpp
+++ b/src/cupp11.hpp
@@ -549,12 +549,12 @@ public:
// Regular constructor with memory management. If this class does not own the buffer object, then
// the memory will not be freed automatically afterwards.
explicit Buffer(const Context &, const BufferAccess access, const size_t size):
- buffer_(new CUdeviceptr, [access](CUdeviceptr* m) {
- if (access != BufferAccess::kNotOwned) { CheckError(cuMemFree(*m)); }
+ buffer_(new CUdeviceptr, [access, size](CUdeviceptr* m) {
+ if (access != BufferAccess::kNotOwned && size > 0) { CheckError(cuMemFree(*m)); }
delete m;
}),
access_(access) {
- CheckError(cuMemAlloc(buffer_.get(), size*sizeof(T)));
+ if (size > 0) { CheckError(cuMemAlloc(buffer_.get(), size*sizeof(T))); }
}
// As above, but now with read/write access as a default
diff --git a/src/database/database.cpp b/src/database/database.cpp
index ed56c65d..b2f70e49 100644
--- a/src/database/database.cpp
+++ b/src/database/database.cpp
@@ -39,6 +39,7 @@
namespace clblast {
// =================================================================================================
+std::vector<database::DatabaseEntry> Database::database = std::vector<database::DatabaseEntry>{};
const std::vector<database::DatabaseEntry> Database::apple_cpu_fallback = std::vector<database::DatabaseEntry>{
database::XaxpyApple, database::XdotApple,
database::XgemvApple, database::XgemvFastApple, database::XgemvFastRotApple, database::XgerApple, database::XtrsvApple,
@@ -58,23 +59,26 @@ Database::Database(const Device &device, const std::string &kernel_name,
const Precision precision, const std::vector<database::DatabaseEntry> &overlay):
parameters_(std::make_shared<database::Parameters>()) {
- database = std::vector<database::DatabaseEntry>{
- database::XaxpyHalf, database::XaxpySingle, database::XaxpyDouble, database::XaxpyComplexSingle, database::XaxpyComplexDouble,
- database::XdotHalf, database::XdotSingle, database::XdotDouble, database::XdotComplexSingle, database::XdotComplexDouble,
- database::XgemvHalf, database::XgemvSingle, database::XgemvDouble, database::XgemvComplexSingle, database::XgemvComplexDouble,
- database::XgemvFastHalf, database::XgemvFastSingle, database::XgemvFastDouble, database::XgemvFastComplexSingle, database::XgemvFastComplexDouble,
- database::XgemvFastRotHalf, database::XgemvFastRotSingle, database::XgemvFastRotDouble, database::XgemvFastRotComplexSingle, database::XgemvFastRotComplexDouble,
- database::XgerHalf, database::XgerSingle, database::XgerDouble, database::XgerComplexSingle, database::XgerComplexDouble,
- database::XgemmHalf, database::XgemmSingle, database::XgemmDouble, database::XgemmComplexSingle, database::XgemmComplexDouble,
- database::XgemmDirectHalf, database::XgemmDirectSingle, database::XgemmDirectDouble, database::XgemmDirectComplexSingle, database::XgemmDirectComplexDouble,
- database::CopyHalf, database::CopySingle, database::CopyDouble, database::CopyComplexSingle, database::CopyComplexDouble,
- database::PadHalf, database::PadSingle, database::PadDouble, database::PadComplexSingle, database::PadComplexDouble,
- database::TransposeHalf, database::TransposeSingle, database::TransposeDouble, database::TransposeComplexSingle, database::TransposeComplexDouble,
- database::PadtransposeHalf, database::PadtransposeSingle, database::PadtransposeDouble, database::PadtransposeComplexSingle, database::PadtransposeComplexDouble,
- database::InvertHalf, database::InvertSingle, database::InvertDouble, database::InvertComplexSingle, database::InvertComplexDouble,
- database::GemmRoutineHalf, database::GemmRoutineSingle, database::GemmRoutineDouble, database::GemmRoutineComplexSingle, database::GemmRoutineComplexDouble,
- database::TrsvRoutineHalf, database::TrsvRoutineSingle, database::TrsvRoutineDouble, database::TrsvRoutineComplexSingle, database::TrsvRoutineComplexDouble
- };
+ // Initializes the static variable on first use. At this point we are sure all global variables are initialized
+ if (database.size() == 0) {
+ database = std::vector<database::DatabaseEntry>{
+ database::XaxpyHalf, database::XaxpySingle, database::XaxpyDouble, database::XaxpyComplexSingle, database::XaxpyComplexDouble,
+ database::XdotHalf, database::XdotSingle, database::XdotDouble, database::XdotComplexSingle, database::XdotComplexDouble,
+ database::XgemvHalf, database::XgemvSingle, database::XgemvDouble, database::XgemvComplexSingle, database::XgemvComplexDouble,
+ database::XgemvFastHalf, database::XgemvFastSingle, database::XgemvFastDouble, database::XgemvFastComplexSingle, database::XgemvFastComplexDouble,
+ database::XgemvFastRotHalf, database::XgemvFastRotSingle, database::XgemvFastRotDouble, database::XgemvFastRotComplexSingle, database::XgemvFastRotComplexDouble,
+ database::XgerHalf, database::XgerSingle, database::XgerDouble, database::XgerComplexSingle, database::XgerComplexDouble,
+ database::XgemmHalf, database::XgemmSingle, database::XgemmDouble, database::XgemmComplexSingle, database::XgemmComplexDouble,
+ database::XgemmDirectHalf, database::XgemmDirectSingle, database::XgemmDirectDouble, database::XgemmDirectComplexSingle, database::XgemmDirectComplexDouble,
+ database::CopyHalf, database::CopySingle, database::CopyDouble, database::CopyComplexSingle, database::CopyComplexDouble,
+ database::PadHalf, database::PadSingle, database::PadDouble, database::PadComplexSingle, database::PadComplexDouble,
+ database::TransposeHalf, database::TransposeSingle, database::TransposeDouble, database::TransposeComplexSingle, database::TransposeComplexDouble,
+ database::PadtransposeHalf, database::PadtransposeSingle, database::PadtransposeDouble, database::PadtransposeComplexSingle, database::PadtransposeComplexDouble,
+ database::InvertHalf, database::InvertSingle, database::InvertDouble, database::InvertComplexSingle, database::InvertComplexDouble,
+ database::GemmRoutineHalf, database::GemmRoutineSingle, database::GemmRoutineDouble, database::GemmRoutineComplexSingle, database::GemmRoutineComplexDouble,
+ database::TrsvRoutineHalf, database::TrsvRoutineSingle, database::TrsvRoutineDouble, database::TrsvRoutineComplexSingle, database::TrsvRoutineComplexDouble
+ };
+ }
// Finds device information
const auto device_type = GetDeviceType(device);
diff --git a/src/database/database.hpp b/src/database/database.hpp
index de4306bc..8e53e013 100644
--- a/src/database/database.hpp
+++ b/src/database/database.hpp
@@ -35,7 +35,7 @@ class Database {
static const std::string kDeviceVendorAll;
// The database consists of separate database entries, stored together in a vector
- std::vector<database::DatabaseEntry> database;
+ static std::vector<database::DatabaseEntry> database;
// Database for a special case: Apple CPUs support limited number of threads
static const std::vector<database::DatabaseEntry> apple_cpu_fallback;
diff --git a/src/routine.cpp b/src/routine.cpp
index 5a1c0fe9..fa5934f6 100644
--- a/src/routine.cpp
+++ b/src/routine.cpp
@@ -62,28 +62,10 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
device_(queue_.GetDevice()),
db_(kernel_names) {
- InitDatabase(userDatabase);
+ InitDatabase(device_, kernel_names, precision, userDatabase, db_);
InitProgram(source);
}
-void Routine::InitDatabase(const std::vector<database::DatabaseEntry> &userDatabase) {
- const auto platform_id = device_.PlatformID();
- for (const auto &kernel_name : kernel_names_) {
-
- // Queries the cache to see whether or not the kernel parameter database is already there
- bool has_db;
- db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ platform_id, device_(), precision_, kernel_name },
- &has_db);
- if (has_db) { continue; }
-
- // Builds the parameter database for this device and routine set and stores it in the cache
- log_debug("Searching database for kernel '" + kernel_name + "'");
- db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase);
- DatabaseCache::Instance().Store(DatabaseKey{ platform_id, device_(), precision_, kernel_name },
- Database{ db_(kernel_name) });
- }
-}
-
void Routine::InitProgram(std::initializer_list<const char *> source) {
// Determines the identifier for this particular routine call
diff --git a/src/routine.hpp b/src/routine.hpp
index a8f1cb6a..00f7b5cc 100644
--- a/src/routine.hpp
+++ b/src/routine.hpp
@@ -33,6 +33,26 @@ namespace clblast {
class Routine {
public:
+ static void InitDatabase(const Device &device, const std::vector<std::string> &kernel_names,
+ const Precision precision, const std::vector<database::DatabaseEntry> &userDatabase,
+ Databases &db) {
+ const auto platform_id = device.PlatformID();
+ for (const auto &kernel_name : kernel_names) {
+
+ // Queries the cache to see whether or not the kernel parameter database is already there
+ bool has_db;
+ db(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device(), precision, kernel_name},
+ &has_db);
+ if (has_db) { continue; }
+
+ // Builds the parameter database for this device and routine set and stores it in the cache
+ log_debug("Searching database for kernel '" + kernel_name + "'");
+ db(kernel_name) = Database(device, kernel_name, precision, userDatabase);
+ DatabaseCache::Instance().Store(DatabaseKey{platform_id, device(), precision, kernel_name},
+ Database{db(kernel_name)});
+ }
+ }
+
// Base class constructor. The user database is an optional extra database to override the
// built-in database.
// All heavy preparation work is done inside this constructor.
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index edba1f00..4c1b9558 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -19,6 +19,11 @@
namespace clblast {
// =================================================================================================
+// Defines the assumptions of the GEMM kernels
+template <typename T> const bool Xgemm<T>::a_want_rotated_ = false;
+template <typename T> const bool Xgemm<T>::b_want_rotated_ = true;
+template <typename T> const bool Xgemm<T>::c_want_rotated_ = false;
+
// Constructor: forwards to base class constructor
template <typename T>
Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name):
@@ -56,40 +61,15 @@ void Xgemm<T>::DoGemm(const Layout layout,
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
- const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld) {
-
- // Makes sure all dimensions are larger than zero
- if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
-
- // Computes whether or not the matrices are transposed in memory. This is based on their layout
- // (row or column-major) and whether or not they are requested to be pre-transposed. Note
- // that the Xgemm kernel expects either matrices A and C (in case of row-major) or B (in case of
- // col-major) to be transformed, so transposing requirements are not the same as whether or not
- // the matrix is actually transposed in memory.
- const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) ||
- (layout == Layout::kRowMajor && a_transpose == Transpose::kNo);
- const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) ||
- (layout == Layout::kRowMajor && b_transpose == Transpose::kNo);
- const auto c_rotated = (layout == Layout::kRowMajor);
- static const auto a_want_rotated = false;
- static const auto b_want_rotated = true;
- static const auto c_want_rotated = false;
- const auto a_do_transpose = a_rotated != a_want_rotated;
- const auto b_do_transpose = b_rotated != b_want_rotated;
- const auto c_do_transpose = c_rotated != c_want_rotated;
-
- // In case of complex data-types, the transpose can also become a conjugate transpose
- const auto a_conjugate = (a_transpose == Transpose::kConjugate);
- const auto b_conjugate = (b_transpose == Transpose::kConjugate);
-
- // Computes the first and second dimensions of the 3 matrices taking into account whether the
- // matrices are rotated or not
- const auto a_one = (a_rotated) ? k : m;
- const auto a_two = (a_rotated) ? m : k;
- const auto b_one = (b_rotated) ? n : k;
- const auto b_two = (b_rotated) ? k : n;
- const auto c_one = (c_rotated) ? n : m;
- const auto c_two = (c_rotated) ? m : n;
+ const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
+ const Buffer<T> &temp_buffer, const bool temp_buffer_provided) { // optional arguments
+
+ // Computes the transpose/conjugate options and sets the a/b/c sizes based on that
+ bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate;
+ size_t a_one, a_two, b_one, b_two, c_one, c_two;
+ ProcessArguments(layout, a_transpose, b_transpose, m, n, k,
+ a_one, a_two, b_one, b_two, c_one, c_two,
+ a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate);
// Tests three matrices (A, B, C) for validity, first from a perspective of the OpenCL buffers and
// their sizes, and then from a perspective of parameter values (e.g. m, n, k). Tests whether the
@@ -103,11 +83,7 @@ void Xgemm<T>::DoGemm(const Layout layout,
TestMatrixC(c_one, c_two, c_buffer, c_offset, c_ld);
// Selects which version of GEMM to run
- const auto m_n_k = static_cast<unsigned long long>(m) * static_cast<unsigned long long>(n) *
- static_cast<unsigned long long>(k);
- const auto database_value = static_cast<unsigned long long>(db_["XGEMM_MIN_INDIRECT_SIZE"]);
- const auto min_indirect_size = database_value * database_value * database_value;
- const auto do_gemm_direct = (m_n_k < min_indirect_size);
+ const auto do_gemm_direct = UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]);
if (do_gemm_direct) { // for small sizes (single kernel)
GemmDirect(m, n, k, alpha,
a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
@@ -119,9 +95,8 @@ void Xgemm<T>::DoGemm(const Layout layout,
a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
c_buffer, c_offset, c_ld,
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
- a_one, a_two, a_want_rotated,
- b_one, b_two, b_want_rotated,
- c_one, c_two, c_want_rotated);
+ a_one, a_two, b_one, b_two, c_one, c_two,
+ temp_buffer, temp_buffer_provided);
}
}
@@ -139,9 +114,11 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
const bool a_conjugate, const bool b_conjugate,
- const size_t a_one, const size_t a_two, const bool a_want_rotated,
- const size_t b_one, const size_t b_two, const bool b_want_rotated,
- const size_t c_one, const size_t c_two, const bool c_want_rotated) {
+ const size_t a_one, const size_t a_two,
+ const size_t b_one, const size_t b_two,
+ const size_t c_one, const size_t c_two,
+ const Buffer<T> &temp_buffer, const bool temp_buffer_provided) {
+
// Calculates the ceiled versions of m, n, and k
const auto m_ceiled = Ceil(m, db_["MWG"]);
const auto n_ceiled = Ceil(n, db_["NWG"]);
@@ -149,39 +126,39 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
// Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account
// whether the matrices need to be rotated or not for the kernel.
- const auto a_one_i = (a_want_rotated) ? k_ceiled : m_ceiled;
- const auto a_two_i = (a_want_rotated) ? m_ceiled : k_ceiled;
- const auto b_one_i = (b_want_rotated) ? n_ceiled : k_ceiled;
- const auto b_two_i = (b_want_rotated) ? k_ceiled : n_ceiled;
- const auto c_one_i = (c_want_rotated) ? n_ceiled : m_ceiled;
- const auto c_two_i = (c_want_rotated) ? m_ceiled : n_ceiled;
+ size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i;
+ CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"],
+ a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i);
// Determines whether or not temporary matrices are needed
- auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && a_offset == 0 &&
- a_do_transpose == false && a_conjugate == false;
- auto b_no_temp = b_one == b_one_i && b_two == b_two_i && b_ld == b_one && b_offset == 0 &&
- b_do_transpose == false && b_conjugate == false;
- auto c_no_temp = c_one == c_one_i && c_two == c_two_i && c_ld == c_one && c_offset == 0 &&
- c_do_transpose == false;
+ auto a_no_temp = NoTempBuffer(a_one, a_one_i, a_two, a_two_i, a_ld, a_offset, a_do_transpose, a_conjugate);
+ auto b_no_temp = NoTempBuffer(b_one, b_one_i, b_two, b_two_i, b_ld, b_offset, b_do_transpose, b_conjugate);
+ auto c_no_temp = NoTempBuffer(c_one, c_one_i, c_two, c_two_i, c_ld, c_offset, c_do_transpose, false);
// Computes the sizes and offsets for (optional) temporary buffers for the 3 matrices
- auto temp_size = size_t{0};
auto b_temp_offset = size_t{0};
auto c_temp_offset = size_t{0};
- if (!a_no_temp) { temp_size += a_one_i*a_two_i; }
- if (!b_no_temp) { b_temp_offset = temp_size; temp_size += b_one_i*b_two_i; }
- if (!c_no_temp) { c_temp_offset = temp_size; temp_size += c_one_i*c_two_i; }
+ const auto temp_size = ComputeTempSize(a_no_temp, b_no_temp, c_no_temp,
+ a_one_i*a_two_i, b_one_i*b_two_i, c_one_i*c_two_i,
+ b_temp_offset, c_temp_offset);
if (!IsMultiple(b_temp_offset, db_["VWN"])) { throw BLASError(StatusCode::kUnexpectedError); }
if (!IsMultiple(c_temp_offset, db_["VWM"])) { throw BLASError(StatusCode::kUnexpectedError); }
// Creates the buffer for the (optional) temporary matrices. Note that we use 'a_buffer' in case
// when no temporary buffer is needed, but that's just to make it compile: it is never used.
- const auto temp_buffer = (temp_size > 0) ? Buffer<T>(context_, temp_size) : a_buffer;
+ const auto temp_buffer_all = (temp_buffer_provided) ? temp_buffer :
+ ((temp_size > 0) ? Buffer<T>(context_, temp_size) : a_buffer);
+
+ // Verifies if the provided temporary buffer is large enough
+ if (temp_buffer_provided) {
+ const auto required_size = temp_size * sizeof(T);
+ if (temp_buffer_all.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryTemp); }
+ }
// Sets the buffer pointers for (temp) matrices A, B, and C
- const auto a_temp = (a_no_temp) ? a_buffer : temp_buffer;
- const auto b_temp = (b_no_temp) ? b_buffer : temp_buffer;
- const auto c_temp = (c_no_temp) ? c_buffer : temp_buffer;
+ const auto a_temp = (a_no_temp) ? a_buffer : temp_buffer_all;
+ const auto b_temp = (b_no_temp) ? b_buffer : temp_buffer_all;
+ const auto c_temp = (c_no_temp) ? c_buffer : temp_buffer_all;
// Events of all kernels (including pre/post processing kernels)
auto eventWaitList = std::vector<Event>();
diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp
index c61611b6..b51d1771 100644
--- a/src/routines/level3/xgemm.hpp
+++ b/src/routines/level3/xgemm.hpp
@@ -24,6 +24,130 @@ template <typename T>
class Xgemm: public Routine {
public:
+ // Defines the assumptions of the GEMM kernels
+ static const bool a_want_rotated_;
+ static const bool b_want_rotated_;
+ static const bool c_want_rotated_;
+
+ // Computes the size of the temporary GEMM buffer based on user-arguments
+ static size_t GetTempSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ const size_t mwg, const size_t nwg, const size_t kwg) {
+
+ // Computes the transpose/conjugate options and sets the a/b/c sizes based on that
+ bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate;
+ size_t a_one, a_two, b_one, b_two, c_one, c_two;
+ ProcessArguments(layout, a_transpose, b_transpose, m, n, k,
+ a_one, a_two, b_one, b_two, c_one, c_two,
+ a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate);
+
+ // Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account
+ // whether the matrices need to be rotated or not for the kernel.
+ size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i;
+ CalculateInternalDimensions(m, n, k, mwg, nwg, kwg,
+ a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i);
+
+ // Determines whether or not temporary matrices are needed
+ auto a_no_temp = NoTempBuffer(a_one, a_one_i, a_two, a_two_i, a_ld, a_offset, a_do_transpose, a_conjugate);
+ auto b_no_temp = NoTempBuffer(b_one, b_one_i, b_two, b_two_i, b_ld, b_offset, b_do_transpose, b_conjugate);
+ auto c_no_temp = NoTempBuffer(c_one, c_one_i, c_two, c_two_i, c_ld, c_offset, c_do_transpose, false);
+
+ // Computes the sizes and offsets for (optional) temporary buffers for the 3 matrices
+ auto b_temp_offset = size_t{0};
+ auto c_temp_offset = size_t{0};
+ return ComputeTempSize(a_no_temp, b_no_temp, c_no_temp,
+ a_one_i*a_two_i, b_one_i*b_two_i, c_one_i*c_two_i,
+ b_temp_offset, c_temp_offset);
+ }
+
+ // Selects which version of GEMM to run
+ static bool UseDirectKernel(const size_t m, const size_t n, const size_t k,
+ const size_t min_indirect_size) {
+ const auto m_n_k = static_cast<unsigned long long>(m) * static_cast<unsigned long long>(n) *
+ static_cast<unsigned long long>(k);
+ const auto min_indirect_size_ll = static_cast<unsigned long long>(min_indirect_size);
+ const auto min_indirect_size_e3 = min_indirect_size_ll * min_indirect_size_ll * min_indirect_size_ll;
+ return (m_n_k < min_indirect_size_e3);
+ }
+
+ // Process the user-arguments, computes secondary parameters
+ static void ProcessArguments(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ size_t& a_one, size_t& a_two, size_t& b_one,
+ size_t& b_two, size_t& c_one, size_t& c_two,
+ bool& a_do_transpose, bool& b_do_transpose, bool& c_do_transpose,
+ bool& a_conjugate, bool& b_conjugate) {
+
+ // Makes sure all dimensions are larger than zero
+ if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
+
+ // Computes whether or not the matrices are transposed in memory. This is based on their layout
+ // (row or column-major) and whether or not they are requested to be pre-transposed. Note
+ // that the Xgemm kernel expects either matrices A and C (in case of row-major) or B (in case of
+ // col-major) to be transformed, so transposing requirements are not the same as whether or not
+ // the matrix is actually transposed in memory.
+ const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) ||
+ (layout == Layout::kRowMajor && a_transpose == Transpose::kNo);
+ const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) ||
+ (layout == Layout::kRowMajor && b_transpose == Transpose::kNo);
+ const auto c_rotated = (layout == Layout::kRowMajor);
+ a_do_transpose = a_rotated != a_want_rotated_;
+ b_do_transpose = b_rotated != b_want_rotated_;
+ c_do_transpose = c_rotated != c_want_rotated_;
+
+ // In case of complex data-types, the transpose can also become a conjugate transpose
+ a_conjugate = (a_transpose == Transpose::kConjugate);
+ b_conjugate = (b_transpose == Transpose::kConjugate);
+
+ // Computes the first and second dimensions of the 3 matrices taking into account whether the
+ // matrices are rotated or not
+ a_one = (a_rotated) ? k : m;
+ a_two = (a_rotated) ? m : k;
+ b_one = (b_rotated) ? n : k;
+ b_two = (b_rotated) ? k : n;
+ c_one = (c_rotated) ? n : m;
+ c_two = (c_rotated) ? m : n;
+ }
+
+ // Computes the sizes and offsets for (optional) temporary buffers for the 3 matrices
+ static size_t ComputeTempSize(const bool a_no_temp, const bool b_no_temp, const bool c_no_temp,
+ const size_t a_size, const size_t b_size, const size_t c_size,
+ size_t &b_temp_offset, size_t &c_temp_offset) {
+ auto temp_size = size_t{0};
+ if (!a_no_temp) { temp_size += a_size; }
+ if (!b_no_temp) { b_temp_offset = temp_size; temp_size += b_size; }
+ if (!c_no_temp) { c_temp_offset = temp_size; temp_size += c_size; }
+ return temp_size;
+ }
+
+ // Determines whether or not temporary matrices are needed
+ static bool NoTempBuffer(const size_t one, const size_t one_i, const size_t two, const size_t two_i,
+ const size_t ld, const size_t offset,
+ const bool do_transpose, const bool conjugate) {
+ return one == one_i && two == two_i && ld == one && offset == 0 && !do_transpose && !conjugate;
+ }
+
+
+ // Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account
+ // whether the matrices need to be rotated or not for the kernel.
+ static void CalculateInternalDimensions(const size_t m, const size_t n, const size_t k,
+ const size_t mwg, const size_t nwg, const size_t kwg,
+ size_t& a_one_i, size_t& a_two_i, size_t& b_one_i,
+ size_t& b_two_i, size_t& c_one_i, size_t& c_two_i) {
+ const auto m_ceiled = Ceil(m, mwg);
+ const auto n_ceiled = Ceil(n, nwg);
+ const auto k_ceiled = Ceil(k, kwg);
+ a_one_i = (a_want_rotated_) ? k_ceiled : m_ceiled;
+ a_two_i = (a_want_rotated_) ? m_ceiled : k_ceiled;
+ b_one_i = (b_want_rotated_) ? n_ceiled : k_ceiled;
+ b_two_i = (b_want_rotated_) ? k_ceiled : n_ceiled;
+ c_one_i = (c_want_rotated_) ? n_ceiled : m_ceiled;
+ c_two_i = (c_want_rotated_) ? m_ceiled : n_ceiled;
+ }
+
// Constructor
Xgemm(Queue &queue, EventPointer event, const std::string &name = "GEMM");
@@ -34,7 +158,8 @@ class Xgemm: public Routine {
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld,
const T beta,
- const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld);
+ const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
+ const Buffer<T> &temp_buffer = Buffer<T>(0), const bool temp_buffer_provided = false);
// Indirect version of GEMM (with pre and post-processing kernels)
void GemmIndirect(const size_t m, const size_t n, const size_t k,
@@ -45,9 +170,10 @@ class Xgemm: public Routine {
const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
const bool a_conjugate, const bool b_conjugate,
- const size_t a_one, const size_t a_two, const bool a_want_rotated,
- const size_t b_one, const size_t b_two, const bool b_want_rotated,
- const size_t c_one, const size_t c_two, const bool c_want_rotated);
+ const size_t a_one, const size_t a_two,
+ const size_t b_one, const size_t b_two,
+ const size_t c_one, const size_t c_two,
+ const Buffer<T> &temp_buffer, const bool temp_buffer_provided);
// Direct version of GEMM (no pre and post-processing kernels)
void GemmDirect(const size_t m, const size_t n, const size_t k,
diff --git a/test/correctness/testblas.hpp b/test/correctness/testblas.hpp
index 4e02fd28..54b2d6f8 100644
--- a/test/correctness/testblas.hpp
+++ b/test/correctness/testblas.hpp
@@ -350,7 +350,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na
for (auto &dilation_h: dilation_hs) { r_args.dilation_h = dilation_h;
for (auto &dilation_w: dilation_ws) { r_args.dilation_w = dilation_w;
for (auto &batch_count: batch_counts) { r_args.batch_count = batch_count;
- C::SetSizes(r_args);
+ C::SetSizes(r_args, tester.queue_);
regular_test_vector.push_back(r_args);
}
}
diff --git a/test/performance/client.cpp b/test/performance/client.cpp
index 8b18b9a9..83088223 100644
--- a/test/performance/client.cpp
+++ b/test/performance/client.cpp
@@ -214,7 +214,7 @@ void Client<T,U>::PerformanceTest(Arguments<U> &args, const SetMetric set_sizes)
while(true) {
// Sets the buffer sizes (routine-specific)
- set_sizes(args);
+ set_sizes(args, queue);
// Populates input host matrices with random data
std::vector<T> x_source(args.x_size);
diff --git a/test/performance/client.hpp b/test/performance/client.hpp
index 0b6176c8..eb224976 100644
--- a/test/performance/client.hpp
+++ b/test/performance/client.hpp
@@ -48,7 +48,7 @@ class Client {
using Reference1 = std::function<StatusCode(const Arguments<U>&, Buffers<T>&, Queue&)>;
using Reference2 = std::function<StatusCode(const Arguments<U>&, BuffersHost<T>&, Queue&)>;
using Reference3 = std::function<StatusCode(const Arguments<U>&, BuffersCUDA<T>&, Queue&)>;
- using SetMetric = std::function<void(Arguments<U>&)>;
+ using SetMetric = std::function<void(Arguments<U>&, Queue&)>;
using GetMetric = std::function<size_t(const Arguments<U>&)>;
// The constructor
diff --git a/test/routines/level1/xamax.hpp b/test/routines/level1/xamax.hpp
index d74807c9..71c1a0ec 100644
--- a/test/routines/level1/xamax.hpp
+++ b/test/routines/level1/xamax.hpp
@@ -47,7 +47,7 @@ class TestXamax {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.x_size = GetSizeX(args);
args.scalar_size = GetSizeImax(args);
}
diff --git a/test/routines/level1/xasum.hpp b/test/routines/level1/xasum.hpp
index 573f1223..62ff895d 100644
--- a/test/routines/level1/xasum.hpp
+++ b/test/routines/level1/xasum.hpp
@@ -47,7 +47,7 @@ class TestXasum {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.x_size = GetSizeX(args);
args.scalar_size = GetSizeAsum(args);
}
diff --git a/test/routines/level1/xaxpy.hpp b/test/routines/level1/xaxpy.hpp
index 7491a9e8..16d21324 100644
--- a/test/routines/level1/xaxpy.hpp
+++ b/test/routines/level1/xaxpy.hpp
@@ -48,7 +48,7 @@ class TestXaxpy {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
}
diff --git a/test/routines/level1/xcopy.hpp b/test/routines/level1/xcopy.hpp
index 58abdbf4..ddef8529 100644
--- a/test/routines/level1/xcopy.hpp
+++ b/test/routines/level1/xcopy.hpp
@@ -47,7 +47,7 @@ class TestXcopy {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
}
diff --git a/test/routines/level1/xdot.hpp b/test/routines/level1/xdot.hpp
index 229d18c9..b668b2df 100644
--- a/test/routines/level1/xdot.hpp
+++ b/test/routines/level1/xdot.hpp
@@ -50,7 +50,7 @@ class TestXdot {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
args.scalar_size = GetSizeDot(args);
diff --git a/test/routines/level1/xdotc.hpp b/test/routines/level1/xdotc.hpp
index 9a1dc33a..8ef2a9b8 100644
--- a/test/routines/level1/xdotc.hpp
+++ b/test/routines/level1/xdotc.hpp
@@ -50,7 +50,7 @@ class TestXdotc {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
args.scalar_size = GetSizeDot(args);
diff --git a/test/routines/level1/xdotu.hpp b/test/routines/level1/xdotu.hpp
index 4b2c7647..cabdf274 100644
--- a/test/routines/level1/xdotu.hpp
+++ b/test/routines/level1/xdotu.hpp
@@ -50,7 +50,7 @@ class TestXdotu {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
args.scalar_size = GetSizeDot(args);
diff --git a/test/routines/level1/xnrm2.hpp b/test/routines/level1/xnrm2.hpp
index f3a789b5..22973e1b 100644
--- a/test/routines/level1/xnrm2.hpp
+++ b/test/routines/level1/xnrm2.hpp
@@ -47,7 +47,7 @@ class TestXnrm2 {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.x_size = GetSizeX(args);
args.scalar_size = GetSizeNrm2(args);
}
diff --git a/test/routines/level1/xscal.hpp b/test/routines/level1/xscal.hpp
index 95038032..34a3d7cf 100644
--- a/test/routines/level1/xscal.hpp
+++ b/test/routines/level1/xscal.hpp
@@ -45,7 +45,7 @@ class TestXscal {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.x_size = GetSizeX(args);
}
diff --git a/test/routines/level1/xswap.hpp b/test/routines/level1/xswap.hpp
index 58310698..61711872 100644
--- a/test/routines/level1/xswap.hpp
+++ b/test/routines/level1/xswap.hpp
@@ -47,7 +47,7 @@ class TestXswap {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
}
diff --git a/test/routines/level2/xgbmv.hpp b/test/routines/level2/xgbmv.hpp
index 7c198e5d..13fcb137 100644
--- a/test/routines/level2/xgbmv.hpp
+++ b/test/routines/level2/xgbmv.hpp
@@ -58,7 +58,7 @@ class TestXgbmv {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
diff --git a/test/routines/level2/xgemv.hpp b/test/routines/level2/xgemv.hpp
index 780e2976..56372ad8 100644
--- a/test/routines/level2/xgemv.hpp
+++ b/test/routines/level2/xgemv.hpp
@@ -58,7 +58,7 @@ class TestXgemv {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
diff --git a/test/routines/level2/xger.hpp b/test/routines/level2/xger.hpp
index 9c5e2e40..85e987cb 100644
--- a/test/routines/level2/xger.hpp
+++ b/test/routines/level2/xger.hpp
@@ -54,7 +54,7 @@ class TestXger {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
diff --git a/test/routines/level2/xgerc.hpp b/test/routines/level2/xgerc.hpp
index 5f58b65d..49200b22 100644
--- a/test/routines/level2/xgerc.hpp
+++ b/test/routines/level2/xgerc.hpp
@@ -54,7 +54,7 @@ class TestXgerc {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
diff --git a/test/routines/level2/xgeru.hpp b/test/routines/level2/xgeru.hpp
index fea3932c..67f9f510 100644
--- a/test/routines/level2/xgeru.hpp
+++ b/test/routines/level2/xgeru.hpp
@@ -54,7 +54,7 @@ class TestXgeru {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
diff --git a/test/routines/level2/xhbmv.hpp b/test/routines/level2/xhbmv.hpp
index 0ccd69b7..4ee9c066 100644
--- a/test/routines/level2/xhbmv.hpp
+++ b/test/routines/level2/xhbmv.hpp
@@ -52,7 +52,7 @@ class TestXhbmv {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
diff --git a/test/routines/level2/xhemv.hpp b/test/routines/level2/xhemv.hpp
index 053bc2dc..b1fb0bee 100644
--- a/test/routines/level2/xhemv.hpp
+++ b/test/routines/level2/xhemv.hpp
@@ -52,7 +52,7 @@ class TestXhemv {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
diff --git a/test/routines/level2/xher.hpp b/test/routines/level2/xher.hpp
index 745df43f..ec183f7a 100644
--- a/test/routines/level2/xher.hpp
+++ b/test/routines/level2/xher.hpp
@@ -49,7 +49,7 @@ class TestXher {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<U> &args) {
+ static void SetSizes(Arguments<U> &args, Queue&) {
args.a_size = GetSizeA(args);
args.x_size = GetSizeX(args);
}
diff --git a/test/routines/level2/xher2.hpp b/test/routines/level2/xher2.hpp
index 794e9a1e..5f442ca4 100644
--- a/test/routines/level2/xher2.hpp
+++ b/test/routines/level2/xher2.hpp
@@ -52,7 +52,7 @@ class TestXher2 {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
diff --git a/test/routines/level2/xhpmv.hpp b/test/routines/level2/xhpmv.hpp
index 157272d3..00f32244 100644
--- a/test/routines/level2/xhpmv.hpp
+++ b/test/routines/level2/xhpmv.hpp
@@ -52,7 +52,7 @@ class TestXhpmv {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.ap_size = GetSizeAP(args);
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
diff --git a/test/routines/level2/xhpr.hpp b/test/routines/level2/xhpr.hpp
index a3bc60d1..1e9bbe29 100644
--- a/test/routines/level2/xhpr.hpp
+++ b/test/routines/level2/xhpr.hpp
@@ -49,7 +49,7 @@ class TestXhpr {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<U> &args) {
+ static void SetSizes(Arguments<U> &args, Queue&) {
args.ap_size = GetSizeAP(args);
args.x_size = GetSizeX(args);
}
diff --git a/test/routines/level2/xhpr2.hpp b/test/routines/level2/xhpr2.hpp
index 1aa6cc54..433a5a93 100644
--- a/test/routines/level2/xhpr2.hpp
+++ b/test/routines/level2/xhpr2.hpp
@@ -52,7 +52,7 @@ class TestXhpr2 {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.ap_size = GetSizeAP(args);
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
diff --git a/test/routines/level2/xsbmv.hpp b/test/routines/level2/xsbmv.hpp
index 51d6441e..12f1ca60 100644
--- a/test/routines/level2/xsbmv.hpp
+++ b/test/routines/level2/xsbmv.hpp
@@ -52,7 +52,7 @@ class TestXsbmv {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
diff --git a/test/routines/level2/xspmv.hpp b/test/routines/level2/xspmv.hpp
index f3089836..a8db7a7b 100644
--- a/test/routines/level2/xspmv.hpp
+++ b/test/routines/level2/xspmv.hpp
@@ -52,7 +52,7 @@ class TestXspmv {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.ap_size = GetSizeAP(args);
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
diff --git a/test/routines/level2/xspr.hpp b/test/routines/level2/xspr.hpp
index d76de610..af17b8cd 100644
--- a/test/routines/level2/xspr.hpp
+++ b/test/routines/level2/xspr.hpp
@@ -49,7 +49,7 @@ class TestXspr {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.ap_size = GetSizeAP(args);
args.x_size = GetSizeX(args);
}
diff --git a/test/routines/level2/xspr2.hpp b/test/routines/level2/xspr2.hpp
index 5ce82a52..b615aca7 100644
--- a/test/routines/level2/xspr2.hpp
+++ b/test/routines/level2/xspr2.hpp
@@ -52,7 +52,7 @@ class TestXspr2 {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.ap_size = GetSizeAP(args);
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
diff --git a/test/routines/level2/xsymv.hpp b/test/routines/level2/xsymv.hpp
index 2a70756d..9c418f88 100644
--- a/test/routines/level2/xsymv.hpp
+++ b/test/routines/level2/xsymv.hpp
@@ -52,7 +52,7 @@ class TestXsymv {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
diff --git a/test/routines/level2/xsyr.hpp b/test/routines/level2/xsyr.hpp
index 02aad990..73818727 100644
--- a/test/routines/level2/xsyr.hpp
+++ b/test/routines/level2/xsyr.hpp
@@ -49,7 +49,7 @@ class TestXsyr {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.x_size = GetSizeX(args);
}
diff --git a/test/routines/level2/xsyr2.hpp b/test/routines/level2/xsyr2.hpp
index 492a9d2d..8cdfb305 100644
--- a/test/routines/level2/xsyr2.hpp
+++ b/test/routines/level2/xsyr2.hpp
@@ -52,7 +52,7 @@ class TestXsyr2 {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
diff --git a/test/routines/level2/xtbmv.hpp b/test/routines/level2/xtbmv.hpp
index 587676ca..8e1e7610 100644
--- a/test/routines/level2/xtbmv.hpp
+++ b/test/routines/level2/xtbmv.hpp
@@ -48,7 +48,7 @@ class TestXtbmv {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.x_size = GetSizeX(args);
}
diff --git a/test/routines/level2/xtpmv.hpp b/test/routines/level2/xtpmv.hpp
index 02f334a2..b28fbf65 100644
--- a/test/routines/level2/xtpmv.hpp
+++ b/test/routines/level2/xtpmv.hpp
@@ -48,7 +48,7 @@ class TestXtpmv {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.ap_size = GetSizeAP(args);
args.x_size = GetSizeX(args);
}
diff --git a/test/routines/level2/xtrmv.hpp b/test/routines/level2/xtrmv.hpp
index 4f2dd582..4f4e3a55 100644
--- a/test/routines/level2/xtrmv.hpp
+++ b/test/routines/level2/xtrmv.hpp
@@ -48,7 +48,7 @@ class TestXtrmv {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.x_size = GetSizeX(args);
}
diff --git a/test/routines/level2/xtrsv.hpp b/test/routines/level2/xtrsv.hpp
index 81508236..52adad38 100644
--- a/test/routines/level2/xtrsv.hpp
+++ b/test/routines/level2/xtrsv.hpp
@@ -48,7 +48,7 @@ class TestXtrsv {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.x_size = GetSizeX(args);
}
diff --git a/test/routines/level3/xgemm.hpp b/test/routines/level3/xgemm.hpp
index fe8cf7b9..4cfa9c83 100644
--- a/test/routines/level3/xgemm.hpp
+++ b/test/routines/level3/xgemm.hpp
@@ -37,7 +37,8 @@ class TestXgemm {
kArgAOffset, kArgBOffset, kArgCOffset,
kArgAlpha, kArgBeta};
}
- static std::vector<std::string> BuffersIn() { return {kBufMatA, kBufMatB, kBufMatC}; }
+ static std::vector<std::string> BuffersIn() { return {kBufMatA, kBufMatB, kBufMatC,
+ kBufMatAP}; } // used as temp buffer
static std::vector<std::string> BuffersOut() { return {kBufMatC}; }
// Describes how to obtain the sizes of the buffers
@@ -60,10 +61,33 @@ class TestXgemm {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue &queue) {
args.a_size = GetSizeA(args);
args.b_size = GetSizeB(args);
args.c_size = GetSizeC(args);
+
+ // Optionally (V != 0) enforces indirect (V == 1) or direct (V == 2) kernels
+ if (V != 0) {
+ const auto device = queue.GetDevice();
+ const auto switch_threshold = (V == 1) ? size_t{0} : size_t{4096}; // large enough for tests
+ const auto override_status = OverrideParameters(device(), "GemmRoutine", PrecisionValue<T>(),
+ {{"XGEMM_MIN_INDIRECT_SIZE", switch_threshold}});
+ if (override_status != StatusCode::kSuccess) { }
+ }
+
+ // Sets the size of the temporary buffer (optional argument to GEMM)
+ auto temp_buffer_size = size_t{0};
+ #ifdef OPENCL_API
+ auto queue_plain = queue();
+ GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
+ args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
+ &queue_plain, temp_buffer_size);
+ #elif CUDA_API
+ GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
+ args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
+ queue.GetDevice()(), temp_buffer_size);
+ #endif
+ args.ap_size = (temp_buffer_size + sizeof(T)) / sizeof(T); // + sizeof(T) to prevent zero
}
// Describes what the default values of the leading dimensions of the matrices are
@@ -83,13 +107,6 @@ class TestXgemm {
// Describes how to run the CLBlast routine
static StatusCode RunRoutine(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
- if (V != 0) {
- const auto device = queue.GetDevice();
- const auto switch_threshold = (V == 1) ? size_t{0} : size_t{1024 * 1024 * 1024}; // large enough for tests
- const auto override_status = OverrideParameters(device(), "GemmRoutine", PrecisionValue<T>(),
- {{"XGEMM_MIN_INDIRECT_SIZE", switch_threshold}});
- if (override_status != StatusCode::kSuccess) { return override_status; }
- }
#ifdef OPENCL_API
auto queue_plain = queue();
auto event = cl_event{};
@@ -98,7 +115,7 @@ class TestXgemm {
buffers.a_mat(), args.a_offset, args.a_ld,
buffers.b_mat(), args.b_offset, args.b_ld, args.beta,
buffers.c_mat(), args.c_offset, args.c_ld,
- &queue_plain, &event);
+ &queue_plain, &event, buffers.ap_mat()); // temp buffer
if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); }
#elif CUDA_API
auto status = Gemm(args.layout, args.a_transpose, args.b_transpose,
@@ -106,7 +123,7 @@ class TestXgemm {
buffers.a_mat(), args.a_offset, args.a_ld,
buffers.b_mat(), args.b_offset, args.b_ld, args.beta,
buffers.c_mat(), args.c_offset, args.c_ld,
- queue.GetContext()(), queue.GetDevice()());
+ queue.GetContext()(), queue.GetDevice()(), buffers.ap_mat()); // temp buffer
cuStreamSynchronize(queue());
#endif
return status;
diff --git a/test/routines/level3/xhemm.hpp b/test/routines/level3/xhemm.hpp
index 3b70d3f1..13e685b9 100644
--- a/test/routines/level3/xhemm.hpp
+++ b/test/routines/level3/xhemm.hpp
@@ -60,7 +60,7 @@ class TestXhemm {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.b_size = GetSizeB(args);
args.c_size = GetSizeC(args);
diff --git a/test/routines/level3/xher2k.hpp b/test/routines/level3/xher2k.hpp
index 6c4e12f1..a8ca4d46 100644
--- a/test/routines/level3/xher2k.hpp
+++ b/test/routines/level3/xher2k.hpp
@@ -58,7 +58,7 @@ class TestXher2k {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<U> &args) {
+ static void SetSizes(Arguments<U> &args, Queue&) {
args.a_size = GetSizeA(args);
args.b_size = GetSizeB(args);
args.c_size = GetSizeC(args);
diff --git a/test/routines/level3/xherk.hpp b/test/routines/level3/xherk.hpp
index c1bb7a0b..3fe14cb2 100644
--- a/test/routines/level3/xherk.hpp
+++ b/test/routines/level3/xherk.hpp
@@ -52,7 +52,7 @@ class TestXherk {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<U> &args) {
+ static void SetSizes(Arguments<U> &args, Queue&) {
args.a_size = GetSizeA(args);
args.c_size = GetSizeC(args);
}
diff --git a/test/routines/level3/xsymm.hpp b/test/routines/level3/xsymm.hpp
index 90cc1888..837e45d8 100644
--- a/test/routines/level3/xsymm.hpp
+++ b/test/routines/level3/xsymm.hpp
@@ -60,7 +60,7 @@ class TestXsymm {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.b_size = GetSizeB(args);
args.c_size = GetSizeC(args);
diff --git a/test/routines/level3/xsyr2k.hpp b/test/routines/level3/xsyr2k.hpp
index 6b29aff7..bf9f3fbf 100644
--- a/test/routines/level3/xsyr2k.hpp
+++ b/test/routines/level3/xsyr2k.hpp
@@ -58,7 +58,7 @@ class TestXsyr2k {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.b_size = GetSizeB(args);
args.c_size = GetSizeC(args);
diff --git a/test/routines/level3/xsyrk.hpp b/test/routines/level3/xsyrk.hpp
index b7782176..23dcf12f 100644
--- a/test/routines/level3/xsyrk.hpp
+++ b/test/routines/level3/xsyrk.hpp
@@ -52,7 +52,7 @@ class TestXsyrk {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.c_size = GetSizeC(args);
}
diff --git a/test/routines/level3/xtrmm.hpp b/test/routines/level3/xtrmm.hpp
index 62d0f573..51377a16 100644
--- a/test/routines/level3/xtrmm.hpp
+++ b/test/routines/level3/xtrmm.hpp
@@ -52,7 +52,7 @@ class TestXtrmm {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.b_size = GetSizeB(args);
}
diff --git a/test/routines/level3/xtrsm.hpp b/test/routines/level3/xtrsm.hpp
index 9ce1f09c..66c8f415 100644
--- a/test/routines/level3/xtrsm.hpp
+++ b/test/routines/level3/xtrsm.hpp
@@ -53,7 +53,7 @@ class TestXtrsm {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.b_size = GetSizeB(args);
}
diff --git a/test/routines/levelx/xaxpybatched.hpp b/test/routines/levelx/xaxpybatched.hpp
index e9715f4e..9a09b47f 100644
--- a/test/routines/levelx/xaxpybatched.hpp
+++ b/test/routines/levelx/xaxpybatched.hpp
@@ -51,7 +51,7 @@ class TestXaxpyBatched {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.x_size = GetSizeX(args);
args.y_size = GetSizeY(args);
diff --git a/test/routines/levelx/xgemmbatched.hpp b/test/routines/levelx/xgemmbatched.hpp
index 2a8bd9d4..e237a018 100644
--- a/test/routines/levelx/xgemmbatched.hpp
+++ b/test/routines/levelx/xgemmbatched.hpp
@@ -71,7 +71,7 @@ class TestXgemmBatched {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.b_size = GetSizeB(args);
args.c_size = GetSizeC(args);
diff --git a/test/routines/levelx/xim2col.hpp b/test/routines/levelx/xim2col.hpp
index e90537fa..092e251d 100644
--- a/test/routines/levelx/xim2col.hpp
+++ b/test/routines/levelx/xim2col.hpp
@@ -62,7 +62,7 @@ public:
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.b_size = GetSizeB(args);
}
diff --git a/test/routines/levelx/xinvert.hpp b/test/routines/levelx/xinvert.hpp
index b8503029..126856ac 100644
--- a/test/routines/levelx/xinvert.hpp
+++ b/test/routines/levelx/xinvert.hpp
@@ -149,7 +149,7 @@ class TestXinvert {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.b_size = GetSizeB(args);
}
diff --git a/test/routines/levelx/xomatcopy.hpp b/test/routines/levelx/xomatcopy.hpp
index 477d6da6..ea35dbe2 100644
--- a/test/routines/levelx/xomatcopy.hpp
+++ b/test/routines/levelx/xomatcopy.hpp
@@ -104,7 +104,7 @@ class TestXomatcopy {
}
// Describes how to set the sizes of all the buffers
- static void SetSizes(Arguments<T> &args) {
+ static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.b_size = GetSizeB(args);
}