summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGELOG1
-rw-r--r--README.md22
-rw-r--r--doc/clblast.md37
-rw-r--r--include/clblast.h12
-rw-r--r--include/clblast_c.h55
-rw-r--r--include/internal/database.h2
-rw-r--r--include/internal/database/xgemm.h12
-rw-r--r--scripts/generator/generator.py18
-rw-r--r--src/clblast.cc55
-rw-r--r--src/clblast_c.cc114
-rw-r--r--src/database.cc2
-rw-r--r--src/routines/level3/xsymm.cc1
-rw-r--r--src/routines/level3/xsyr2k.cc2
-rw-r--r--src/routines/level3/xsyrk.cc2
-rw-r--r--src/routines/level3/xtrmm.cc1
15 files changed, 296 insertions, 40 deletions
diff --git a/CHANGELOG b/CHANGELOG
index 328f044a..bf21f58e 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -4,6 +4,7 @@ Development version (next release)
- Added half-precision routines:
* Level-1: HSWAP/HSCAL/HCOPY/HAXPY/HDOT/HNRM2/HASUM/HSUM/iHAMAX/iHMAX/iHMIN
* Level-2: HGEMV/HGBMV/HHEMV/HHBMV/HHPMV/HSYMV/HSBMV/HSPMV/HTRMV/HTBMV/HTPMV/HGER/HSYR/HSPR/HSYR2/HSPR2
+ * Level-3: HGEMM/HSYMM/HSYRK/HSYR2K/HTRMM
Version 0.7.1
- Improved performance of large power-of-2 xGEMM kernels for AMD GPUs
diff --git a/README.md b/README.md
index f1a50ff4..51c282a3 100644
--- a/README.md
+++ b/README.md
@@ -128,7 +128,7 @@ If your device is not (yet) among this list or if you want to tune CLBlast for s
cmake -DTUNERS=ON ..
-Note that CLBlast's tuners are based on the CLTune auto-tuning library, which has to be installed separately (version 2.3.0 or higher). CLTune is available from GitHub.
+Note that CLBlast's tuners are based on the CLTune auto-tuning library, which has to be installed separately (version 2.3.1 or higher). CLTune is available from GitHub.
Compiling with `-DTUNERS=ON` will generate a number of tuners, each named `clblast_tuner_xxxxx`, in which `xxxxx` corresponds to a `.opencl` kernel file as found in `src/kernels`. These kernels corresponds to routines (e.g. `xgemm`) or to common pre-processing or post-processing kernels (`copy` and `transpose`). Running such a tuner will test a number of parameter-value combinations on your device and report which one gave the best performance. Running `make alltuners` runs all tuners for all precisions in one go. You can set the default device and platform for `alltuners` by setting the `DEFAULT_DEVICE` and `DEFAULT_PLATFORM` environmental variables before running CMake.
@@ -224,16 +224,16 @@ CLBlast is in active development but already supports almost all the BLAS routin
| xSYR2 | ✔ | ✔ | - | - | ✔ |
| xSPR2 | ✔ | ✔ | - | - | ✔ |
-| Level-3 | S | D | C | Z |
-| ---------|---|---|---|---|
-| xGEMM | ✔ | ✔ | ✔ | ✔ |
-| xSYMM | ✔ | ✔ | ✔ | ✔ |
-| xHEMM | - | - | ✔ | ✔ |
-| xSYRK | ✔ | ✔ | ✔ | ✔ |
-| xHERK | - | - | ✔ | ✔ |
-| xSYR2K | ✔ | ✔ | ✔ | ✔ |
-| xHER2K | - | - | ✔ | ✔ |
-| xTRMM | ✔ | ✔ | ✔ | ✔ |
+| Level-3 | S | D | C | Z | H |
+| ---------|---|---|---|---|---|
+| xGEMM | ✔ | ✔ | ✔ | ✔ | ✔ |
+| xSYMM | ✔ | ✔ | ✔ | ✔ | ✔ |
+| xHEMM | - | - | ✔ | ✔ | - |
+| xSYRK | ✔ | ✔ | ✔ | ✔ | ✔ |
+| xHERK | - | - | ✔ | ✔ | - |
+| xSYR2K | ✔ | ✔ | ✔ | ✔ | ✔ |
+| xHER2K | - | - | ✔ | ✔ | - |
+| xTRMM | ✔ | ✔ | ✔ | ✔ | ✔ |
In addition, some non-BLAS routines are also supported by CLBlast. They are experimental and should be used with care:
diff --git a/doc/clblast.md b/doc/clblast.md
index 6f3f09c2..8dbb97e4 100644
--- a/doc/clblast.md
+++ b/doc/clblast.md
@@ -2075,6 +2075,14 @@ StatusCode CLBlastZgemm(const Layout layout, const Transpose a_transpose, const
const cl_double2 beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
cl_command_queue* queue, cl_event* event)
+StatusCode CLBlastHgemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_half beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event)
```
Arguments to GEMM:
@@ -2153,6 +2161,14 @@ StatusCode CLBlastZsymm(const Layout layout, const Side side, const Triangle tri
const cl_double2 beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
cl_command_queue* queue, cl_event* event)
+StatusCode CLBlastHsymm(const Layout layout, const Side side, const Triangle triangle,
+ const size_t m, const size_t n,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_half beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event)
```
Arguments to SYMM:
@@ -2286,6 +2302,13 @@ StatusCode CLBlastZsyrk(const Layout layout, const Triangle triangle, const Tran
const cl_double2 beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
cl_command_queue* queue, cl_event* event)
+StatusCode CLBlastHsyrk(const Layout layout, const Triangle triangle, const Transpose a_transpose,
+ const size_t n, const size_t k,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_half beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event)
```
Arguments to SYRK:
@@ -2415,6 +2438,14 @@ StatusCode CLBlastZsyr2k(const Layout layout, const Triangle triangle, const Tra
const cl_double2 beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
cl_command_queue* queue, cl_event* event)
+StatusCode CLBlastHsyr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose,
+ const size_t n, const size_t k,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_half beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event)
```
Arguments to SYR2K:
@@ -2543,6 +2574,12 @@ StatusCode CLBlastZtrmm(const Layout layout, const Side side, const Triangle tri
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
cl_command_queue* queue, cl_event* event)
+StatusCode CLBlastHtrmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event)
```
Arguments to TRMM:
diff --git a/include/clblast.h b/include/clblast.h
index d7b952ba..64b2610a 100644
--- a/include/clblast.h
+++ b/include/clblast.h
@@ -466,7 +466,7 @@ StatusCode Spr2(const Layout layout, const Triangle triangle,
// BLAS level-3 (matrix-matrix) routines
// =================================================================================================
-// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM
+// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM
template <typename T>
StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k,
@@ -477,7 +477,7 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
cl_command_queue* queue, cl_event* event = nullptr);
-// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM
+// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
template <typename T>
StatusCode Symm(const Layout layout, const Side side, const Triangle triangle,
const size_t m, const size_t n,
@@ -499,7 +499,7 @@ StatusCode Hemm(const Layout layout, const Side side, const Triangle triangle,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
cl_command_queue* queue, cl_event* event = nullptr);
-// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK
+// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK/HSYRK
template <typename T>
StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose a_transpose,
const size_t n, const size_t k,
@@ -519,7 +519,7 @@ StatusCode Herk(const Layout layout, const Triangle triangle, const Transpose a_
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
cl_command_queue* queue, cl_event* event = nullptr);
-// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K
+// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K/HSYR2K
template <typename T>
StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose,
const size_t n, const size_t k,
@@ -541,7 +541,7 @@ StatusCode Her2k(const Layout layout, const Triangle triangle, const Transpose a
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
cl_command_queue* queue, cl_event* event = nullptr);
-// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM
+// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM/HTRMM
template <typename T>
StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
const size_t m, const size_t n,
@@ -550,7 +550,7 @@ StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, c
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
cl_command_queue* queue, cl_event* event = nullptr);
-// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM
+// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM
template <typename T>
StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
const size_t m, const size_t n,
diff --git a/include/clblast_c.h b/include/clblast_c.h
index 92392921..40248615 100644
--- a/include/clblast_c.h
+++ b/include/clblast_c.h
@@ -986,7 +986,7 @@ StatusCode PUBLIC_API CLBlastHspr2(const Layout layout, const Triangle triangle,
// BLAS level-3 (matrix-matrix) routines
// =================================================================================================
-// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM
+// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM
StatusCode PUBLIC_API CLBlastSgemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k,
const float alpha,
@@ -1019,8 +1019,16 @@ StatusCode PUBLIC_API CLBlastZgemm(const Layout layout, const Transpose a_transp
const cl_double2 beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
cl_command_queue* queue, cl_event* event);
+StatusCode PUBLIC_API CLBlastHgemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_half beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event);
-// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM
+// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
StatusCode PUBLIC_API CLBlastSsymm(const Layout layout, const Side side, const Triangle triangle,
const size_t m, const size_t n,
const float alpha,
@@ -1053,6 +1061,14 @@ StatusCode PUBLIC_API CLBlastZsymm(const Layout layout, const Side side, const T
const cl_double2 beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
cl_command_queue* queue, cl_event* event);
+StatusCode PUBLIC_API CLBlastHsymm(const Layout layout, const Side side, const Triangle triangle,
+ const size_t m, const size_t n,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_half beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event);
// Hermitian matrix-matrix multiplication: CHEMM/ZHEMM
StatusCode PUBLIC_API CLBlastChemm(const Layout layout, const Side side, const Triangle triangle,
@@ -1072,7 +1088,7 @@ StatusCode PUBLIC_API CLBlastZhemm(const Layout layout, const Side side, const T
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
cl_command_queue* queue, cl_event* event);
-// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK
+// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK/HSYRK
StatusCode PUBLIC_API CLBlastSsyrk(const Layout layout, const Triangle triangle, const Transpose a_transpose,
const size_t n, const size_t k,
const float alpha,
@@ -1101,6 +1117,13 @@ StatusCode PUBLIC_API CLBlastZsyrk(const Layout layout, const Triangle triangle,
const cl_double2 beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
cl_command_queue* queue, cl_event* event);
+StatusCode PUBLIC_API CLBlastHsyrk(const Layout layout, const Triangle triangle, const Transpose a_transpose,
+ const size_t n, const size_t k,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_half beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event);
// Rank-K update of a hermitian matrix: CHERK/ZHERK
StatusCode PUBLIC_API CLBlastCherk(const Layout layout, const Triangle triangle, const Transpose a_transpose,
@@ -1118,7 +1141,7 @@ StatusCode PUBLIC_API CLBlastZherk(const Layout layout, const Triangle triangle,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
cl_command_queue* queue, cl_event* event);
-// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K
+// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K/HSYR2K
StatusCode PUBLIC_API CLBlastSsyr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose,
const size_t n, const size_t k,
const float alpha,
@@ -1151,6 +1174,14 @@ StatusCode PUBLIC_API CLBlastZsyr2k(const Layout layout, const Triangle triangle
const cl_double2 beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
cl_command_queue* queue, cl_event* event);
+StatusCode PUBLIC_API CLBlastHsyr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose,
+ const size_t n, const size_t k,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_half beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event);
// Rank-2K update of a hermitian matrix: CHER2K/ZHER2K
StatusCode PUBLIC_API CLBlastCher2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose,
@@ -1170,7 +1201,7 @@ StatusCode PUBLIC_API CLBlastZher2k(const Layout layout, const Triangle triangle
cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
cl_command_queue* queue, cl_event* event);
-// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM
+// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM/HTRMM
StatusCode PUBLIC_API CLBlastStrmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
const size_t m, const size_t n,
const float alpha,
@@ -1195,8 +1226,14 @@ StatusCode PUBLIC_API CLBlastZtrmm(const Layout layout, const Side side, const T
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
cl_command_queue* queue, cl_event* event);
+StatusCode PUBLIC_API CLBlastHtrmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event);
-// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM
+// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM
StatusCode PUBLIC_API CLBlastStrsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
const size_t m, const size_t n,
const float alpha,
@@ -1221,6 +1258,12 @@ StatusCode PUBLIC_API CLBlastZtrsm(const Layout layout, const Side side, const T
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
cl_command_queue* queue, cl_event* event);
+StatusCode PUBLIC_API CLBlastHtrsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event);
// =================================================================================================
diff --git a/include/internal/database.h b/include/internal/database.h
index 34629bf5..f93eaa22 100644
--- a/include/internal/database.h
+++ b/include/internal/database.h
@@ -71,7 +71,7 @@ class Database {
static const DatabaseEntry XdotHalf, XdotSingle, XdotDouble, XdotComplexSingle, XdotComplexDouble;
static const DatabaseEntry XgemvHalf, XgemvSingle, XgemvDouble, XgemvComplexSingle, XgemvComplexDouble;
static const DatabaseEntry XgerHalf, XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble;
- static const DatabaseEntry XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble;
+ static const DatabaseEntry XgemmHalf, XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble;
static const DatabaseEntry CopyHalf, CopySingle, CopyDouble, CopyComplexSingle, CopyComplexDouble;
static const DatabaseEntry PadHalf, PadSingle, PadDouble, PadComplexSingle, PadComplexDouble;
static const DatabaseEntry TransposeHalf, TransposeSingle, TransposeDouble, TransposeComplexSingle, TransposeComplexDouble;
diff --git a/include/internal/database/xgemm.h b/include/internal/database/xgemm.h
index 9ca2bff5..647188e9 100644
--- a/include/internal/database/xgemm.h
+++ b/include/internal/database/xgemm.h
@@ -14,6 +14,18 @@
namespace clblast {
// =================================================================================================
+const Database::DatabaseEntry Database::XgemmHalf = {
+ "Xgemm", Precision::kHalf, {
+ { // Default
+ kDeviceTypeAll, "default", {
+ { "default", { {"KWG",16}, {"KWI",2}, {"MDIMA",8}, {"MDIMC",8}, {"MWG",32}, {"NDIMB",8}, {"NDIMC",8}, {"NWG",64}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} } },
+ }
+ },
+ }
+};
+
+// =================================================================================================
+
const Database::DatabaseEntry Database::XgemmSingle = {
"Xgemm", Precision::kSingle, {
{ // AMD GPUs
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 8dd1f77a..7a8ff9f8 100644
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -109,15 +109,15 @@ routines = [
Routine(True, True, "2b", "spr2", T, [S,D,H], ["n"], ["layout","triangle"], ["x","y"], ["ap"], ["alpha"], "", "Symmetric packed rank-2 matrix update", "", []),
],
[ # Level 3: matrix-matrix
- Routine(True, True, "3", "gemm", T, [S,D,C,Z], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "General matrix-matrix multiplication", "", []),
- Routine(True, True, "3", "symm", T, [S,D,C,Z], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], ["alpha","beta"], "", "Symmetric matrix-matrix multiplication", "", []),
- Routine(True, True, "3", "hemm", T, [C,Z], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], ["alpha","beta"], "", "Hermitian matrix-matrix multiplication", "", []),
- Routine(True, True, "3", "syrk", T, [S,D,C,Z], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], ["alpha","beta"], "", "Rank-K update of a symmetric matrix", "", []),
- Routine(True, True, "3", "herk", Tc, [Css,Zdd], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], ["alpha","beta"], "", "Rank-K update of a hermitian matrix", "", []),
- Routine(True, True, "3", "syr2k", T, [S,D,C,Z], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "Rank-2K update of a symmetric matrix", "", []),
- Routine(True, True, "3", "her2k", TU, [Ccs,Zzd], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "Rank-2K update of a hermitian matrix", "", []),
- Routine(True, True, "3", "trmm", T, [S,D,C,Z], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], ["alpha"], "", "Triangular matrix-matrix multiplication", "", []),
- Routine(False, True, "3", "trsm", T, [S,D,C,Z], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], ["alpha"], "", "Solves a triangular system of equations", "", []),
+ Routine(True, True, "3", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "General matrix-matrix multiplication", "", []),
+ Routine(True, True, "3", "symm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], ["alpha","beta"], "", "Symmetric matrix-matrix multiplication", "", []),
+ Routine(True, True, "3", "hemm", T, [C,Z], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], ["alpha","beta"], "", "Hermitian matrix-matrix multiplication", "", []),
+ Routine(True, True, "3", "syrk", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], ["alpha","beta"], "", "Rank-K update of a symmetric matrix", "", []),
+ Routine(True, True, "3", "herk", Tc, [Css,Zdd], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], ["alpha","beta"], "", "Rank-K update of a hermitian matrix", "", []),
+ Routine(True, True, "3", "syr2k", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "Rank-2K update of a symmetric matrix", "", []),
+ Routine(True, True, "3", "her2k", TU, [Ccs,Zzd], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "Rank-2K update of a hermitian matrix", "", []),
+ Routine(True, True, "3", "trmm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], ["alpha"], "", "Triangular matrix-matrix multiplication", "", []),
+ Routine(False, True, "3", "trsm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], ["alpha"], "", "Solves a triangular system of equations", "", []),
]]
# ==================================================================================================
diff --git a/src/clblast.cc b/src/clblast.cc
index 449c7321..07322327 100644
--- a/src/clblast.cc
+++ b/src/clblast.cc
@@ -1613,7 +1613,7 @@ template StatusCode PUBLIC_API Spr2<half>(const Layout, const Triangle,
// BLAS level-3 (matrix-matrix) routines
// =================================================================================================
-// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM
+// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM
template <typename T>
StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k,
@@ -1667,8 +1667,16 @@ template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, cons
const double2,
cl_mem, const size_t, const size_t,
cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const half,
+ const cl_mem, const size_t, const size_t,
+ const cl_mem, const size_t, const size_t,
+ const half,
+ cl_mem, const size_t, const size_t,
+ cl_command_queue*, cl_event*);
-// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM
+// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
template <typename T>
StatusCode Symm(const Layout layout, const Side side, const Triangle triangle,
const size_t m, const size_t n,
@@ -1722,6 +1730,14 @@ template StatusCode PUBLIC_API Symm<double2>(const Layout, const Side, const Tri
const double2,
cl_mem, const size_t, const size_t,
cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API Symm<half>(const Layout, const Side, const Triangle,
+ const size_t, const size_t,
+ const half,
+ const cl_mem, const size_t, const size_t,
+ const cl_mem, const size_t, const size_t,
+ const half,
+ cl_mem, const size_t, const size_t,
+ cl_command_queue*, cl_event*);
// Hermitian matrix-matrix multiplication: CHEMM/ZHEMM
template <typename T>
@@ -1762,7 +1778,7 @@ template StatusCode PUBLIC_API Hemm<double2>(const Layout, const Side, const Tri
cl_mem, const size_t, const size_t,
cl_command_queue*, cl_event*);
-// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK
+// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK/HSYRK
template <typename T>
StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose a_transpose,
const size_t n, const size_t k,
@@ -1810,6 +1826,13 @@ template StatusCode PUBLIC_API Syrk<double2>(const Layout, const Triangle, const
const double2,
cl_mem, const size_t, const size_t,
cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API Syrk<half>(const Layout, const Triangle, const Transpose,
+ const size_t, const size_t,
+ const half,
+ const cl_mem, const size_t, const size_t,
+ const half,
+ cl_mem, const size_t, const size_t,
+ cl_command_queue*, cl_event*);
// Rank-K update of a hermitian matrix: CHERK/ZHERK
template <typename T>
@@ -1846,7 +1869,7 @@ template StatusCode PUBLIC_API Herk<double>(const Layout, const Triangle, const
cl_mem, const size_t, const size_t,
cl_command_queue*, cl_event*);
-// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K
+// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K/HSYR2K
template <typename T>
StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose,
const size_t n, const size_t k,
@@ -1900,6 +1923,14 @@ template StatusCode PUBLIC_API Syr2k<double2>(const Layout, const Triangle, cons
const double2,
cl_mem, const size_t, const size_t,
cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API Syr2k<half>(const Layout, const Triangle, const Transpose,
+ const size_t, const size_t,
+ const half,
+ const cl_mem, const size_t, const size_t,
+ const cl_mem, const size_t, const size_t,
+ const half,
+ cl_mem, const size_t, const size_t,
+ cl_command_queue*, cl_event*);
// Rank-2K update of a hermitian matrix: CHER2K/ZHER2K
template <typename T, typename U>
@@ -1940,7 +1971,7 @@ template StatusCode PUBLIC_API Her2k<double2,double>(const Layout, const Triangl
cl_mem, const size_t, const size_t,
cl_command_queue*, cl_event*);
-// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM
+// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM/HTRMM
template <typename T>
StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
const size_t m, const size_t n,
@@ -1982,8 +2013,14 @@ template StatusCode PUBLIC_API Trmm<double2>(const Layout, const Side, const Tri
const cl_mem, const size_t, const size_t,
cl_mem, const size_t, const size_t,
cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API Trmm<half>(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
+ const size_t, const size_t,
+ const half,
+ const cl_mem, const size_t, const size_t,
+ cl_mem, const size_t, const size_t,
+ cl_command_queue*, cl_event*);
-// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM
+// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM
template <typename T>
StatusCode Trsm(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
const size_t, const size_t,
@@ -2017,6 +2054,12 @@ template StatusCode PUBLIC_API Trsm<double2>(const Layout, const Side, const Tri
const cl_mem, const size_t, const size_t,
cl_mem, const size_t, const size_t,
cl_command_queue*, cl_event*);
+template StatusCode PUBLIC_API Trsm<half>(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
+ const size_t, const size_t,
+ const half,
+ const cl_mem, const size_t, const size_t,
+ cl_mem, const size_t, const size_t,
+ cl_command_queue*, cl_event*);
// =================================================================================================
diff --git a/src/clblast_c.cc b/src/clblast_c.cc
index c368a03c..2aac907a 100644
--- a/src/clblast_c.cc
+++ b/src/clblast_c.cc
@@ -2208,6 +2208,26 @@ StatusCode CLBlastZgemm(const Layout layout, const Transpose a_transpose, const
queue, event);
return static_cast<StatusCode>(status);
}
+StatusCode CLBlastHgemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_half beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event) {
+ auto status = clblast::Gemm(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Transpose>(b_transpose),
+ m, n, k,
+ alpha,
+ a_buffer, a_offset, a_ld,
+ b_buffer, b_offset, b_ld,
+ beta,
+ c_buffer, c_offset, c_ld,
+ queue, event);
+ return static_cast<StatusCode>(status);
+}
// SYMM
StatusCode CLBlastSsymm(const Layout layout, const Side side, const Triangle triangle,
@@ -2290,6 +2310,26 @@ StatusCode CLBlastZsymm(const Layout layout, const Side side, const Triangle tri
queue, event);
return static_cast<StatusCode>(status);
}
+StatusCode CLBlastHsymm(const Layout layout, const Side side, const Triangle triangle,
+ const size_t m, const size_t n,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_half beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event) {
+ auto status = clblast::Symm(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Side>(side),
+ static_cast<clblast::Triangle>(triangle),
+ m, n,
+ alpha,
+ a_buffer, a_offset, a_ld,
+ b_buffer, b_offset, b_ld,
+ beta,
+ c_buffer, c_offset, c_ld,
+ queue, event);
+ return static_cast<StatusCode>(status);
+}
// HEMM
StatusCode CLBlastChemm(const Layout layout, const Side side, const Triangle triangle,
@@ -2406,6 +2446,24 @@ StatusCode CLBlastZsyrk(const Layout layout, const Triangle triangle, const Tran
queue, event);
return static_cast<StatusCode>(status);
}
+StatusCode CLBlastHsyrk(const Layout layout, const Triangle triangle, const Transpose a_transpose,
+ const size_t n, const size_t k,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_half beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event) {
+ auto status = clblast::Syrk(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Triangle>(triangle),
+ static_cast<clblast::Transpose>(a_transpose),
+ n, k,
+ alpha,
+ a_buffer, a_offset, a_ld,
+ beta,
+ c_buffer, c_offset, c_ld,
+ queue, event);
+ return static_cast<StatusCode>(status);
+}
// HERK
StatusCode CLBlastCherk(const Layout layout, const Triangle triangle, const Transpose a_transpose,
@@ -2526,6 +2584,26 @@ StatusCode CLBlastZsyr2k(const Layout layout, const Triangle triangle, const Tra
queue, event);
return static_cast<StatusCode>(status);
}
+StatusCode CLBlastHsyr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose,
+ const size_t n, const size_t k,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ const cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ const cl_half beta,
+ cl_mem c_buffer, const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, cl_event* event) {
+ auto status = clblast::Syr2k(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Triangle>(triangle),
+ static_cast<clblast::Transpose>(ab_transpose),
+ n, k,
+ alpha,
+ a_buffer, a_offset, a_ld,
+ b_buffer, b_offset, b_ld,
+ beta,
+ c_buffer, c_offset, c_ld,
+ queue, event);
+ return static_cast<StatusCode>(status);
+}
// HER2K
StatusCode CLBlastCher2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose,
@@ -2642,6 +2720,24 @@ StatusCode CLBlastZtrmm(const Layout layout, const Side side, const Triangle tri
queue, event);
return static_cast<StatusCode>(status);
}
+StatusCode CLBlastHtrmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event) {
+ auto status = clblast::Trmm(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Side>(side),
+ static_cast<clblast::Triangle>(triangle),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Diagonal>(diagonal),
+ m, n,
+ alpha,
+ a_buffer, a_offset, a_ld,
+ b_buffer, b_offset, b_ld,
+ queue, event);
+ return static_cast<StatusCode>(status);
+}
// TRSM
StatusCode CLBlastStrsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
@@ -2716,6 +2812,24 @@ StatusCode CLBlastZtrsm(const Layout layout, const Side side, const Triangle tri
queue, event);
return static_cast<StatusCode>(status);
}
+StatusCode CLBlastHtrsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
+ const size_t m, const size_t n,
+ const cl_half alpha,
+ const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
+ cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
+ cl_command_queue* queue, cl_event* event) {
+ auto status = clblast::Trsm(static_cast<clblast::Layout>(layout),
+ static_cast<clblast::Side>(side),
+ static_cast<clblast::Triangle>(triangle),
+ static_cast<clblast::Transpose>(a_transpose),
+ static_cast<clblast::Diagonal>(diagonal),
+ m, n,
+ alpha,
+ a_buffer, a_offset, a_ld,
+ b_buffer, b_offset, b_ld,
+ queue, event);
+ return static_cast<StatusCode>(status);
+}
// =================================================================================================
diff --git a/src/database.cc b/src/database.cc
index dc72dbdd..e20ae340 100644
--- a/src/database.cc
+++ b/src/database.cc
@@ -33,7 +33,7 @@ const std::vector<Database::DatabaseEntry> Database::database = {
XdotHalf, XdotSingle, XdotDouble, XdotComplexSingle, XdotComplexDouble,
XgemvHalf, XgemvSingle, XgemvDouble, XgemvComplexSingle, XgemvComplexDouble,
XgerHalf, XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble,
- XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble,
+ XgemmHalf, XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble,
CopyHalf, CopySingle, CopyDouble, CopyComplexSingle, CopyComplexDouble,
PadHalf, PadSingle, PadDouble, PadComplexSingle, PadComplexDouble,
TransposeHalf, TransposeSingle, TransposeDouble, TransposeComplexSingle, TransposeComplexDouble,
diff --git a/src/routines/level3/xsymm.cc b/src/routines/level3/xsymm.cc
index a39026f1..d88d4653 100644
--- a/src/routines/level3/xsymm.cc
+++ b/src/routines/level3/xsymm.cc
@@ -127,6 +127,7 @@ StatusCode Xsymm<T>::DoSymm(const Layout layout, const Side side, const Triangle
// =================================================================================================
// Compiles the templated class
+template class Xsymm<half>;
template class Xsymm<float>;
template class Xsymm<double>;
template class Xsymm<float2>;
diff --git a/src/routines/level3/xsyr2k.cc b/src/routines/level3/xsyr2k.cc
index c52e1353..4f86bac5 100644
--- a/src/routines/level3/xsyr2k.cc
+++ b/src/routines/level3/xsyr2k.cc
@@ -20,6 +20,7 @@ namespace clblast {
// =================================================================================================
// Specific implementations to get the memory-type based on a template argument
+template <> const Precision Xsyr2k<half>::precision_ = Precision::kHalf;
template <> const Precision Xsyr2k<float>::precision_ = Precision::kSingle;
template <> const Precision Xsyr2k<double>::precision_ = Precision::kDouble;
template <> const Precision Xsyr2k<float2>::precision_ = Precision::kComplexSingle;
@@ -203,6 +204,7 @@ StatusCode Xsyr2k<T>::DoSyr2k(const Layout layout, const Triangle triangle, cons
// =================================================================================================
// Compiles the templated class
+template class Xsyr2k<half>;
template class Xsyr2k<float>;
template class Xsyr2k<double>;
template class Xsyr2k<float2>;
diff --git a/src/routines/level3/xsyrk.cc b/src/routines/level3/xsyrk.cc
index cfcd4e12..52cb58c0 100644
--- a/src/routines/level3/xsyrk.cc
+++ b/src/routines/level3/xsyrk.cc
@@ -20,6 +20,7 @@ namespace clblast {
// =================================================================================================
// Specific implementations to get the memory-type based on a template argument
+template <> const Precision Xsyrk<half>::precision_ = Precision::kHalf;
template <> const Precision Xsyrk<float>::precision_ = Precision::kSingle;
template <> const Precision Xsyrk<double>::precision_ = Precision::kDouble;
template <> const Precision Xsyrk<float2>::precision_ = Precision::kComplexSingle;
@@ -175,6 +176,7 @@ StatusCode Xsyrk<T>::DoSyrk(const Layout layout, const Triangle triangle, const
// =================================================================================================
// Compiles the templated class
+template class Xsyrk<half>;
template class Xsyrk<float>;
template class Xsyrk<double>;
template class Xsyrk<float2>;
diff --git a/src/routines/level3/xtrmm.cc b/src/routines/level3/xtrmm.cc
index 9e3b27b4..18cbb1c0 100644
--- a/src/routines/level3/xtrmm.cc
+++ b/src/routines/level3/xtrmm.cc
@@ -130,6 +130,7 @@ StatusCode Xtrmm<T>::DoTrmm(const Layout layout, const Side side, const Triangle
// =================================================================================================
// Compiles the templated class
+template class Xtrmm<half>;
template class Xtrmm<float>;
template class Xtrmm<double>;
template class Xtrmm<float2>;