diff options
-rw-r--r-- | CHANGELOG | 1 | ||||
-rw-r--r-- | README.md | 22 | ||||
-rw-r--r-- | doc/clblast.md | 37 | ||||
-rw-r--r-- | include/clblast.h | 12 | ||||
-rw-r--r-- | include/clblast_c.h | 55 | ||||
-rw-r--r-- | include/internal/database.h | 2 | ||||
-rw-r--r-- | include/internal/database/xgemm.h | 12 | ||||
-rw-r--r-- | scripts/generator/generator.py | 18 | ||||
-rw-r--r-- | src/clblast.cc | 55 | ||||
-rw-r--r-- | src/clblast_c.cc | 114 | ||||
-rw-r--r-- | src/database.cc | 2 | ||||
-rw-r--r-- | src/routines/level3/xsymm.cc | 1 | ||||
-rw-r--r-- | src/routines/level3/xsyr2k.cc | 2 | ||||
-rw-r--r-- | src/routines/level3/xsyrk.cc | 2 | ||||
-rw-r--r-- | src/routines/level3/xtrmm.cc | 1 |
15 files changed, 296 insertions, 40 deletions
@@ -4,6 +4,7 @@ Development version (next release) - Added half-precision routines: * Level-1: HSWAP/HSCAL/HCOPY/HAXPY/HDOT/HNRM2/HASUM/HSUM/iHAMAX/iHMAX/iHMIN * Level-2: HGEMV/HGBMV/HHEMV/HHBMV/HHPMV/HSYMV/HSBMV/HSPMV/HTRMV/HTBMV/HTPMV/HGER/HSYR/HSPR/HSYR2/HSPR2 + * Level-3: HGEMM/HSYMM/HSYRK/HSYR2K/HTRMM Version 0.7.1 - Improved performance of large power-of-2 xGEMM kernels for AMD GPUs @@ -128,7 +128,7 @@ If your device is not (yet) among this list or if you want to tune CLBlast for s cmake -DTUNERS=ON .. -Note that CLBlast's tuners are based on the CLTune auto-tuning library, which has to be installed separately (version 2.3.0 or higher). CLTune is available from GitHub. +Note that CLBlast's tuners are based on the CLTune auto-tuning library, which has to be installed separately (version 2.3.1 or higher). CLTune is available from GitHub. Compiling with `-DTUNERS=ON` will generate a number of tuners, each named `clblast_tuner_xxxxx`, in which `xxxxx` corresponds to a `.opencl` kernel file as found in `src/kernels`. These kernels corresponds to routines (e.g. `xgemm`) or to common pre-processing or post-processing kernels (`copy` and `transpose`). Running such a tuner will test a number of parameter-value combinations on your device and report which one gave the best performance. Running `make alltuners` runs all tuners for all precisions in one go. You can set the default device and platform for `alltuners` by setting the `DEFAULT_DEVICE` and `DEFAULT_PLATFORM` environmental variables before running CMake. @@ -224,16 +224,16 @@ CLBlast is in active development but already supports almost all the BLAS routin | xSYR2 | ✔ | ✔ | - | - | ✔ | | xSPR2 | ✔ | ✔ | - | - | ✔ | -| Level-3 | S | D | C | Z | -| ---------|---|---|---|---| -| xGEMM | ✔ | ✔ | ✔ | ✔ | -| xSYMM | ✔ | ✔ | ✔ | ✔ | -| xHEMM | - | - | ✔ | ✔ | -| xSYRK | ✔ | ✔ | ✔ | ✔ | -| xHERK | - | - | ✔ | ✔ | -| xSYR2K | ✔ | ✔ | ✔ | ✔ | -| xHER2K | - | - | ✔ | ✔ | -| xTRMM | ✔ | ✔ | ✔ | ✔ | +| Level-3 | S | D | C | Z | H | +| ---------|---|---|---|---|---| +| xGEMM | ✔ | ✔ | ✔ | ✔ | ✔ | +| xSYMM | ✔ | ✔ | ✔ | ✔ | ✔ | +| xHEMM | - | - | ✔ | ✔ | - | +| xSYRK | ✔ | ✔ | ✔ | ✔ | ✔ | +| xHERK | - | - | ✔ | ✔ | - | +| xSYR2K | ✔ | ✔ | ✔ | ✔ | ✔ | +| xHER2K | - | - | ✔ | ✔ | - | +| xTRMM | ✔ | ✔ | ✔ | ✔ | ✔ | In addition, some non-BLAS routines are also supported by CLBlast. They are experimental and should be used with care: diff --git a/doc/clblast.md b/doc/clblast.md index 6f3f09c2..8dbb97e4 100644 --- a/doc/clblast.md +++ b/doc/clblast.md @@ -2075,6 +2075,14 @@ StatusCode CLBlastZgemm(const Layout layout, const Transpose a_transpose, const const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event) +StatusCode CLBlastHgemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event) ``` Arguments to GEMM: @@ -2153,6 +2161,14 @@ StatusCode CLBlastZsymm(const Layout layout, const Side side, const Triangle tri const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event) +StatusCode CLBlastHsymm(const Layout layout, const Side side, const Triangle triangle, + const size_t m, const size_t n, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event) ``` Arguments to SYMM: @@ -2286,6 +2302,13 @@ StatusCode CLBlastZsyrk(const Layout layout, const Triangle triangle, const Tran const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event) +StatusCode CLBlastHsyrk(const Layout layout, const Triangle triangle, const Transpose a_transpose, + const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event) ``` Arguments to SYRK: @@ -2415,6 +2438,14 @@ StatusCode CLBlastZsyr2k(const Layout layout, const Triangle triangle, const Tra const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event) +StatusCode CLBlastHsyr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, + const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event) ``` Arguments to SYR2K: @@ -2543,6 +2574,12 @@ StatusCode CLBlastZtrmm(const Layout layout, const Side side, const Triangle tri const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, cl_mem b_buffer, const size_t b_offset, const size_t b_ld, cl_command_queue* queue, cl_event* event) +StatusCode CLBlastHtrmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + cl_command_queue* queue, cl_event* event) ``` Arguments to TRMM: diff --git a/include/clblast.h b/include/clblast.h index d7b952ba..64b2610a 100644 --- a/include/clblast.h +++ b/include/clblast.h @@ -466,7 +466,7 @@ StatusCode Spr2(const Layout layout, const Triangle triangle, // BLAS level-3 (matrix-matrix) routines // ================================================================================================= -// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM +// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM template <typename T> StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, const size_t m, const size_t n, const size_t k, @@ -477,7 +477,7 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event = nullptr); -// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM +// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM template <typename T> StatusCode Symm(const Layout layout, const Side side, const Triangle triangle, const size_t m, const size_t n, @@ -499,7 +499,7 @@ StatusCode Hemm(const Layout layout, const Side side, const Triangle triangle, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event = nullptr); -// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK +// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK/HSYRK template <typename T> StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose a_transpose, const size_t n, const size_t k, @@ -519,7 +519,7 @@ StatusCode Herk(const Layout layout, const Triangle triangle, const Transpose a_ cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event = nullptr); -// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K +// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K/HSYR2K template <typename T> StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, const size_t n, const size_t k, @@ -541,7 +541,7 @@ StatusCode Her2k(const Layout layout, const Triangle triangle, const Transpose a cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event = nullptr); -// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM +// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM/HTRMM template <typename T> StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, const size_t m, const size_t n, @@ -550,7 +550,7 @@ StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, c cl_mem b_buffer, const size_t b_offset, const size_t b_ld, cl_command_queue* queue, cl_event* event = nullptr); -// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM +// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM template <typename T> StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, const size_t m, const size_t n, diff --git a/include/clblast_c.h b/include/clblast_c.h index 92392921..40248615 100644 --- a/include/clblast_c.h +++ b/include/clblast_c.h @@ -986,7 +986,7 @@ StatusCode PUBLIC_API CLBlastHspr2(const Layout layout, const Triangle triangle, // BLAS level-3 (matrix-matrix) routines // ================================================================================================= -// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM +// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM StatusCode PUBLIC_API CLBlastSgemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, const size_t m, const size_t n, const size_t k, const float alpha, @@ -1019,8 +1019,16 @@ StatusCode PUBLIC_API CLBlastZgemm(const Layout layout, const Transpose a_transp const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event); +StatusCode PUBLIC_API CLBlastHgemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event); -// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM +// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM StatusCode PUBLIC_API CLBlastSsymm(const Layout layout, const Side side, const Triangle triangle, const size_t m, const size_t n, const float alpha, @@ -1053,6 +1061,14 @@ StatusCode PUBLIC_API CLBlastZsymm(const Layout layout, const Side side, const T const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event); +StatusCode PUBLIC_API CLBlastHsymm(const Layout layout, const Side side, const Triangle triangle, + const size_t m, const size_t n, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event); // Hermitian matrix-matrix multiplication: CHEMM/ZHEMM StatusCode PUBLIC_API CLBlastChemm(const Layout layout, const Side side, const Triangle triangle, @@ -1072,7 +1088,7 @@ StatusCode PUBLIC_API CLBlastZhemm(const Layout layout, const Side side, const T cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event); -// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK +// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK/HSYRK StatusCode PUBLIC_API CLBlastSsyrk(const Layout layout, const Triangle triangle, const Transpose a_transpose, const size_t n, const size_t k, const float alpha, @@ -1101,6 +1117,13 @@ StatusCode PUBLIC_API CLBlastZsyrk(const Layout layout, const Triangle triangle, const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event); +StatusCode PUBLIC_API CLBlastHsyrk(const Layout layout, const Triangle triangle, const Transpose a_transpose, + const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event); // Rank-K update of a hermitian matrix: CHERK/ZHERK StatusCode PUBLIC_API CLBlastCherk(const Layout layout, const Triangle triangle, const Transpose a_transpose, @@ -1118,7 +1141,7 @@ StatusCode PUBLIC_API CLBlastZherk(const Layout layout, const Triangle triangle, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event); -// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K +// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K/HSYR2K StatusCode PUBLIC_API CLBlastSsyr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, const size_t n, const size_t k, const float alpha, @@ -1151,6 +1174,14 @@ StatusCode PUBLIC_API CLBlastZsyr2k(const Layout layout, const Triangle triangle const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event); +StatusCode PUBLIC_API CLBlastHsyr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, + const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event); // Rank-2K update of a hermitian matrix: CHER2K/ZHER2K StatusCode PUBLIC_API CLBlastCher2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, @@ -1170,7 +1201,7 @@ StatusCode PUBLIC_API CLBlastZher2k(const Layout layout, const Triangle triangle cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event); -// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM +// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM/HTRMM StatusCode PUBLIC_API CLBlastStrmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, const size_t m, const size_t n, const float alpha, @@ -1195,8 +1226,14 @@ StatusCode PUBLIC_API CLBlastZtrmm(const Layout layout, const Side side, const T const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, cl_mem b_buffer, const size_t b_offset, const size_t b_ld, cl_command_queue* queue, cl_event* event); +StatusCode PUBLIC_API CLBlastHtrmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + cl_command_queue* queue, cl_event* event); -// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM +// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM StatusCode PUBLIC_API CLBlastStrsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, const size_t m, const size_t n, const float alpha, @@ -1221,6 +1258,12 @@ StatusCode PUBLIC_API CLBlastZtrsm(const Layout layout, const Side side, const T const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, cl_mem b_buffer, const size_t b_offset, const size_t b_ld, cl_command_queue* queue, cl_event* event); +StatusCode PUBLIC_API CLBlastHtrsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + cl_command_queue* queue, cl_event* event); // ================================================================================================= diff --git a/include/internal/database.h b/include/internal/database.h index 34629bf5..f93eaa22 100644 --- a/include/internal/database.h +++ b/include/internal/database.h @@ -71,7 +71,7 @@ class Database { static const DatabaseEntry XdotHalf, XdotSingle, XdotDouble, XdotComplexSingle, XdotComplexDouble; static const DatabaseEntry XgemvHalf, XgemvSingle, XgemvDouble, XgemvComplexSingle, XgemvComplexDouble; static const DatabaseEntry XgerHalf, XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble; - static const DatabaseEntry XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble; + static const DatabaseEntry XgemmHalf, XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble; static const DatabaseEntry CopyHalf, CopySingle, CopyDouble, CopyComplexSingle, CopyComplexDouble; static const DatabaseEntry PadHalf, PadSingle, PadDouble, PadComplexSingle, PadComplexDouble; static const DatabaseEntry TransposeHalf, TransposeSingle, TransposeDouble, TransposeComplexSingle, TransposeComplexDouble; diff --git a/include/internal/database/xgemm.h b/include/internal/database/xgemm.h index 9ca2bff5..647188e9 100644 --- a/include/internal/database/xgemm.h +++ b/include/internal/database/xgemm.h @@ -14,6 +14,18 @@ namespace clblast { // ================================================================================================= +const Database::DatabaseEntry Database::XgemmHalf = { + "Xgemm", Precision::kHalf, { + { // Default + kDeviceTypeAll, "default", { + { "default", { {"KWG",16}, {"KWI",2}, {"MDIMA",8}, {"MDIMC",8}, {"MWG",32}, {"NDIMB",8}, {"NDIMC",8}, {"NWG",64}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} } }, + } + }, + } +}; + +// ================================================================================================= + const Database::DatabaseEntry Database::XgemmSingle = { "Xgemm", Precision::kSingle, { { // AMD GPUs diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index 8dd1f77a..7a8ff9f8 100644 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -109,15 +109,15 @@ routines = [ Routine(True, True, "2b", "spr2", T, [S,D,H], ["n"], ["layout","triangle"], ["x","y"], ["ap"], ["alpha"], "", "Symmetric packed rank-2 matrix update", "", []), ], [ # Level 3: matrix-matrix - Routine(True, True, "3", "gemm", T, [S,D,C,Z], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "General matrix-matrix multiplication", "", []), - Routine(True, True, "3", "symm", T, [S,D,C,Z], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], ["alpha","beta"], "", "Symmetric matrix-matrix multiplication", "", []), - Routine(True, True, "3", "hemm", T, [C,Z], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], ["alpha","beta"], "", "Hermitian matrix-matrix multiplication", "", []), - Routine(True, True, "3", "syrk", T, [S,D,C,Z], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], ["alpha","beta"], "", "Rank-K update of a symmetric matrix", "", []), - Routine(True, True, "3", "herk", Tc, [Css,Zdd], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], ["alpha","beta"], "", "Rank-K update of a hermitian matrix", "", []), - Routine(True, True, "3", "syr2k", T, [S,D,C,Z], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "Rank-2K update of a symmetric matrix", "", []), - Routine(True, True, "3", "her2k", TU, [Ccs,Zzd], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "Rank-2K update of a hermitian matrix", "", []), - Routine(True, True, "3", "trmm", T, [S,D,C,Z], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], ["alpha"], "", "Triangular matrix-matrix multiplication", "", []), - Routine(False, True, "3", "trsm", T, [S,D,C,Z], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], ["alpha"], "", "Solves a triangular system of equations", "", []), + Routine(True, True, "3", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "General matrix-matrix multiplication", "", []), + Routine(True, True, "3", "symm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], ["alpha","beta"], "", "Symmetric matrix-matrix multiplication", "", []), + Routine(True, True, "3", "hemm", T, [C,Z], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], ["alpha","beta"], "", "Hermitian matrix-matrix multiplication", "", []), + Routine(True, True, "3", "syrk", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], ["alpha","beta"], "", "Rank-K update of a symmetric matrix", "", []), + Routine(True, True, "3", "herk", Tc, [Css,Zdd], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], ["alpha","beta"], "", "Rank-K update of a hermitian matrix", "", []), + Routine(True, True, "3", "syr2k", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "Rank-2K update of a symmetric matrix", "", []), + Routine(True, True, "3", "her2k", TU, [Ccs,Zzd], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "Rank-2K update of a hermitian matrix", "", []), + Routine(True, True, "3", "trmm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], ["alpha"], "", "Triangular matrix-matrix multiplication", "", []), + Routine(False, True, "3", "trsm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], ["alpha"], "", "Solves a triangular system of equations", "", []), ]] # ================================================================================================== diff --git a/src/clblast.cc b/src/clblast.cc index 449c7321..07322327 100644 --- a/src/clblast.cc +++ b/src/clblast.cc @@ -1613,7 +1613,7 @@ template StatusCode PUBLIC_API Spr2<half>(const Layout, const Triangle, // BLAS level-3 (matrix-matrix) routines // ================================================================================================= -// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM +// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM template <typename T> StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, const size_t m, const size_t n, const size_t k, @@ -1667,8 +1667,16 @@ template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, cons const double2, cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const half, + const cl_mem, const size_t, const size_t, + const cl_mem, const size_t, const size_t, + const half, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); -// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM +// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM template <typename T> StatusCode Symm(const Layout layout, const Side side, const Triangle triangle, const size_t m, const size_t n, @@ -1722,6 +1730,14 @@ template StatusCode PUBLIC_API Symm<double2>(const Layout, const Side, const Tri const double2, cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Symm<half>(const Layout, const Side, const Triangle, + const size_t, const size_t, + const half, + const cl_mem, const size_t, const size_t, + const cl_mem, const size_t, const size_t, + const half, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); // Hermitian matrix-matrix multiplication: CHEMM/ZHEMM template <typename T> @@ -1762,7 +1778,7 @@ template StatusCode PUBLIC_API Hemm<double2>(const Layout, const Side, const Tri cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); -// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK +// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK/HSYRK template <typename T> StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose a_transpose, const size_t n, const size_t k, @@ -1810,6 +1826,13 @@ template StatusCode PUBLIC_API Syrk<double2>(const Layout, const Triangle, const const double2, cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Syrk<half>(const Layout, const Triangle, const Transpose, + const size_t, const size_t, + const half, + const cl_mem, const size_t, const size_t, + const half, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); // Rank-K update of a hermitian matrix: CHERK/ZHERK template <typename T> @@ -1846,7 +1869,7 @@ template StatusCode PUBLIC_API Herk<double>(const Layout, const Triangle, const cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); -// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K +// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K/HSYR2K template <typename T> StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, const size_t n, const size_t k, @@ -1900,6 +1923,14 @@ template StatusCode PUBLIC_API Syr2k<double2>(const Layout, const Triangle, cons const double2, cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Syr2k<half>(const Layout, const Triangle, const Transpose, + const size_t, const size_t, + const half, + const cl_mem, const size_t, const size_t, + const cl_mem, const size_t, const size_t, + const half, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); // Rank-2K update of a hermitian matrix: CHER2K/ZHER2K template <typename T, typename U> @@ -1940,7 +1971,7 @@ template StatusCode PUBLIC_API Her2k<double2,double>(const Layout, const Triangl cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); -// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM +// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM/HTRMM template <typename T> StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, const size_t m, const size_t n, @@ -1982,8 +2013,14 @@ template StatusCode PUBLIC_API Trmm<double2>(const Layout, const Side, const Tri const cl_mem, const size_t, const size_t, cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Trmm<half>(const Layout, const Side, const Triangle, const Transpose, const Diagonal, + const size_t, const size_t, + const half, + const cl_mem, const size_t, const size_t, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); -// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM +// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM template <typename T> StatusCode Trsm(const Layout, const Side, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, @@ -2017,6 +2054,12 @@ template StatusCode PUBLIC_API Trsm<double2>(const Layout, const Side, const Tri const cl_mem, const size_t, const size_t, cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Trsm<half>(const Layout, const Side, const Triangle, const Transpose, const Diagonal, + const size_t, const size_t, + const half, + const cl_mem, const size_t, const size_t, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); // ================================================================================================= diff --git a/src/clblast_c.cc b/src/clblast_c.cc index c368a03c..2aac907a 100644 --- a/src/clblast_c.cc +++ b/src/clblast_c.cc @@ -2208,6 +2208,26 @@ StatusCode CLBlastZgemm(const Layout layout, const Transpose a_transpose, const queue, event); return static_cast<StatusCode>(status); } +StatusCode CLBlastHgemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event) { + auto status = clblast::Gemm(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + alpha, + a_buffer, a_offset, a_ld, + b_buffer, b_offset, b_ld, + beta, + c_buffer, c_offset, c_ld, + queue, event); + return static_cast<StatusCode>(status); +} // SYMM StatusCode CLBlastSsymm(const Layout layout, const Side side, const Triangle triangle, @@ -2290,6 +2310,26 @@ StatusCode CLBlastZsymm(const Layout layout, const Side side, const Triangle tri queue, event); return static_cast<StatusCode>(status); } +StatusCode CLBlastHsymm(const Layout layout, const Side side, const Triangle triangle, + const size_t m, const size_t n, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event) { + auto status = clblast::Symm(static_cast<clblast::Layout>(layout), + static_cast<clblast::Side>(side), + static_cast<clblast::Triangle>(triangle), + m, n, + alpha, + a_buffer, a_offset, a_ld, + b_buffer, b_offset, b_ld, + beta, + c_buffer, c_offset, c_ld, + queue, event); + return static_cast<StatusCode>(status); +} // HEMM StatusCode CLBlastChemm(const Layout layout, const Side side, const Triangle triangle, @@ -2406,6 +2446,24 @@ StatusCode CLBlastZsyrk(const Layout layout, const Triangle triangle, const Tran queue, event); return static_cast<StatusCode>(status); } +StatusCode CLBlastHsyrk(const Layout layout, const Triangle triangle, const Transpose a_transpose, + const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event) { + auto status = clblast::Syrk(static_cast<clblast::Layout>(layout), + static_cast<clblast::Triangle>(triangle), + static_cast<clblast::Transpose>(a_transpose), + n, k, + alpha, + a_buffer, a_offset, a_ld, + beta, + c_buffer, c_offset, c_ld, + queue, event); + return static_cast<StatusCode>(status); +} // HERK StatusCode CLBlastCherk(const Layout layout, const Triangle triangle, const Transpose a_transpose, @@ -2526,6 +2584,26 @@ StatusCode CLBlastZsyr2k(const Layout layout, const Triangle triangle, const Tra queue, event); return static_cast<StatusCode>(status); } +StatusCode CLBlastHsyr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, + const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event) { + auto status = clblast::Syr2k(static_cast<clblast::Layout>(layout), + static_cast<clblast::Triangle>(triangle), + static_cast<clblast::Transpose>(ab_transpose), + n, k, + alpha, + a_buffer, a_offset, a_ld, + b_buffer, b_offset, b_ld, + beta, + c_buffer, c_offset, c_ld, + queue, event); + return static_cast<StatusCode>(status); +} // HER2K StatusCode CLBlastCher2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, @@ -2642,6 +2720,24 @@ StatusCode CLBlastZtrmm(const Layout layout, const Side side, const Triangle tri queue, event); return static_cast<StatusCode>(status); } +StatusCode CLBlastHtrmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + cl_command_queue* queue, cl_event* event) { + auto status = clblast::Trmm(static_cast<clblast::Layout>(layout), + static_cast<clblast::Side>(side), + static_cast<clblast::Triangle>(triangle), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Diagonal>(diagonal), + m, n, + alpha, + a_buffer, a_offset, a_ld, + b_buffer, b_offset, b_ld, + queue, event); + return static_cast<StatusCode>(status); +} // TRSM StatusCode CLBlastStrsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, @@ -2716,6 +2812,24 @@ StatusCode CLBlastZtrsm(const Layout layout, const Side side, const Triangle tri queue, event); return static_cast<StatusCode>(status); } +StatusCode CLBlastHtrsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + cl_command_queue* queue, cl_event* event) { + auto status = clblast::Trsm(static_cast<clblast::Layout>(layout), + static_cast<clblast::Side>(side), + static_cast<clblast::Triangle>(triangle), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Diagonal>(diagonal), + m, n, + alpha, + a_buffer, a_offset, a_ld, + b_buffer, b_offset, b_ld, + queue, event); + return static_cast<StatusCode>(status); +} // ================================================================================================= diff --git a/src/database.cc b/src/database.cc index dc72dbdd..e20ae340 100644 --- a/src/database.cc +++ b/src/database.cc @@ -33,7 +33,7 @@ const std::vector<Database::DatabaseEntry> Database::database = { XdotHalf, XdotSingle, XdotDouble, XdotComplexSingle, XdotComplexDouble, XgemvHalf, XgemvSingle, XgemvDouble, XgemvComplexSingle, XgemvComplexDouble, XgerHalf, XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble, - XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble, + XgemmHalf, XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble, CopyHalf, CopySingle, CopyDouble, CopyComplexSingle, CopyComplexDouble, PadHalf, PadSingle, PadDouble, PadComplexSingle, PadComplexDouble, TransposeHalf, TransposeSingle, TransposeDouble, TransposeComplexSingle, TransposeComplexDouble, diff --git a/src/routines/level3/xsymm.cc b/src/routines/level3/xsymm.cc index a39026f1..d88d4653 100644 --- a/src/routines/level3/xsymm.cc +++ b/src/routines/level3/xsymm.cc @@ -127,6 +127,7 @@ StatusCode Xsymm<T>::DoSymm(const Layout layout, const Side side, const Triangle // ================================================================================================= // Compiles the templated class +template class Xsymm<half>; template class Xsymm<float>; template class Xsymm<double>; template class Xsymm<float2>; diff --git a/src/routines/level3/xsyr2k.cc b/src/routines/level3/xsyr2k.cc index c52e1353..4f86bac5 100644 --- a/src/routines/level3/xsyr2k.cc +++ b/src/routines/level3/xsyr2k.cc @@ -20,6 +20,7 @@ namespace clblast { // ================================================================================================= // Specific implementations to get the memory-type based on a template argument +template <> const Precision Xsyr2k<half>::precision_ = Precision::kHalf; template <> const Precision Xsyr2k<float>::precision_ = Precision::kSingle; template <> const Precision Xsyr2k<double>::precision_ = Precision::kDouble; template <> const Precision Xsyr2k<float2>::precision_ = Precision::kComplexSingle; @@ -203,6 +204,7 @@ StatusCode Xsyr2k<T>::DoSyr2k(const Layout layout, const Triangle triangle, cons // ================================================================================================= // Compiles the templated class +template class Xsyr2k<half>; template class Xsyr2k<float>; template class Xsyr2k<double>; template class Xsyr2k<float2>; diff --git a/src/routines/level3/xsyrk.cc b/src/routines/level3/xsyrk.cc index cfcd4e12..52cb58c0 100644 --- a/src/routines/level3/xsyrk.cc +++ b/src/routines/level3/xsyrk.cc @@ -20,6 +20,7 @@ namespace clblast { // ================================================================================================= // Specific implementations to get the memory-type based on a template argument +template <> const Precision Xsyrk<half>::precision_ = Precision::kHalf; template <> const Precision Xsyrk<float>::precision_ = Precision::kSingle; template <> const Precision Xsyrk<double>::precision_ = Precision::kDouble; template <> const Precision Xsyrk<float2>::precision_ = Precision::kComplexSingle; @@ -175,6 +176,7 @@ StatusCode Xsyrk<T>::DoSyrk(const Layout layout, const Triangle triangle, const // ================================================================================================= // Compiles the templated class +template class Xsyrk<half>; template class Xsyrk<float>; template class Xsyrk<double>; template class Xsyrk<float2>; diff --git a/src/routines/level3/xtrmm.cc b/src/routines/level3/xtrmm.cc index 9e3b27b4..18cbb1c0 100644 --- a/src/routines/level3/xtrmm.cc +++ b/src/routines/level3/xtrmm.cc @@ -130,6 +130,7 @@ StatusCode Xtrmm<T>::DoTrmm(const Layout layout, const Side side, const Triangle // ================================================================================================= // Compiles the templated class +template class Xtrmm<half>; template class Xtrmm<float>; template class Xtrmm<double>; template class Xtrmm<float2>; |