diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2016-05-25 13:29:53 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2016-05-25 13:29:53 +0200 |
commit | 9f8745507020961b1c287febc3a5634b46ccb0e9 (patch) | |
tree | ff776b8b8fcf56529eaeada54a6c05c4bdfff264 /include | |
parent | ac1575056e0f3d7406cc7bcbbdbe71b08feb58ce (diff) |
Added level-3 half-precision routines HGEMM/HSYMM/HSYRK/HSYR2K/HTRMM
Diffstat (limited to 'include')
-rw-r--r-- | include/clblast.h | 12 | ||||
-rw-r--r-- | include/clblast_c.h | 55 | ||||
-rw-r--r-- | include/internal/database.h | 2 | ||||
-rw-r--r-- | include/internal/database/xgemm.h | 12 |
4 files changed, 68 insertions, 13 deletions
diff --git a/include/clblast.h b/include/clblast.h index d7b952ba..64b2610a 100644 --- a/include/clblast.h +++ b/include/clblast.h @@ -466,7 +466,7 @@ StatusCode Spr2(const Layout layout, const Triangle triangle, // BLAS level-3 (matrix-matrix) routines // ================================================================================================= -// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM +// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM template <typename T> StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, const size_t m, const size_t n, const size_t k, @@ -477,7 +477,7 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event = nullptr); -// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM +// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM template <typename T> StatusCode Symm(const Layout layout, const Side side, const Triangle triangle, const size_t m, const size_t n, @@ -499,7 +499,7 @@ StatusCode Hemm(const Layout layout, const Side side, const Triangle triangle, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event = nullptr); -// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK +// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK/HSYRK template <typename T> StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose a_transpose, const size_t n, const size_t k, @@ -519,7 +519,7 @@ StatusCode Herk(const Layout layout, const Triangle triangle, const Transpose a_ cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event = nullptr); -// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K +// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K/HSYR2K template <typename T> StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, const size_t n, const size_t k, @@ -541,7 +541,7 @@ StatusCode Her2k(const Layout layout, const Triangle triangle, const Transpose a cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event = nullptr); -// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM +// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM/HTRMM template <typename T> StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, const size_t m, const size_t n, @@ -550,7 +550,7 @@ StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, c cl_mem b_buffer, const size_t b_offset, const size_t b_ld, cl_command_queue* queue, cl_event* event = nullptr); -// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM +// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM template <typename T> StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, const size_t m, const size_t n, diff --git a/include/clblast_c.h b/include/clblast_c.h index 92392921..40248615 100644 --- a/include/clblast_c.h +++ b/include/clblast_c.h @@ -986,7 +986,7 @@ StatusCode PUBLIC_API CLBlastHspr2(const Layout layout, const Triangle triangle, // BLAS level-3 (matrix-matrix) routines // ================================================================================================= -// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM +// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM StatusCode PUBLIC_API CLBlastSgemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, const size_t m, const size_t n, const size_t k, const float alpha, @@ -1019,8 +1019,16 @@ StatusCode PUBLIC_API CLBlastZgemm(const Layout layout, const Transpose a_transp const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event); +StatusCode PUBLIC_API CLBlastHgemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event); -// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM +// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM StatusCode PUBLIC_API CLBlastSsymm(const Layout layout, const Side side, const Triangle triangle, const size_t m, const size_t n, const float alpha, @@ -1053,6 +1061,14 @@ StatusCode PUBLIC_API CLBlastZsymm(const Layout layout, const Side side, const T const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event); +StatusCode PUBLIC_API CLBlastHsymm(const Layout layout, const Side side, const Triangle triangle, + const size_t m, const size_t n, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event); // Hermitian matrix-matrix multiplication: CHEMM/ZHEMM StatusCode PUBLIC_API CLBlastChemm(const Layout layout, const Side side, const Triangle triangle, @@ -1072,7 +1088,7 @@ StatusCode PUBLIC_API CLBlastZhemm(const Layout layout, const Side side, const T cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event); -// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK +// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK/HSYRK StatusCode PUBLIC_API CLBlastSsyrk(const Layout layout, const Triangle triangle, const Transpose a_transpose, const size_t n, const size_t k, const float alpha, @@ -1101,6 +1117,13 @@ StatusCode PUBLIC_API CLBlastZsyrk(const Layout layout, const Triangle triangle, const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event); +StatusCode PUBLIC_API CLBlastHsyrk(const Layout layout, const Triangle triangle, const Transpose a_transpose, + const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event); // Rank-K update of a hermitian matrix: CHERK/ZHERK StatusCode PUBLIC_API CLBlastCherk(const Layout layout, const Triangle triangle, const Transpose a_transpose, @@ -1118,7 +1141,7 @@ StatusCode PUBLIC_API CLBlastZherk(const Layout layout, const Triangle triangle, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event); -// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K +// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K/HSYR2K StatusCode PUBLIC_API CLBlastSsyr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, const size_t n, const size_t k, const float alpha, @@ -1151,6 +1174,14 @@ StatusCode PUBLIC_API CLBlastZsyr2k(const Layout layout, const Triangle triangle const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event); +StatusCode PUBLIC_API CLBlastHsyr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, + const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event); // Rank-2K update of a hermitian matrix: CHER2K/ZHER2K StatusCode PUBLIC_API CLBlastCher2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, @@ -1170,7 +1201,7 @@ StatusCode PUBLIC_API CLBlastZher2k(const Layout layout, const Triangle triangle cl_mem c_buffer, const size_t c_offset, const size_t c_ld, cl_command_queue* queue, cl_event* event); -// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM +// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM/HTRMM StatusCode PUBLIC_API CLBlastStrmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, const size_t m, const size_t n, const float alpha, @@ -1195,8 +1226,14 @@ StatusCode PUBLIC_API CLBlastZtrmm(const Layout layout, const Side side, const T const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, cl_mem b_buffer, const size_t b_offset, const size_t b_ld, cl_command_queue* queue, cl_event* event); +StatusCode PUBLIC_API CLBlastHtrmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + cl_command_queue* queue, cl_event* event); -// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM +// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM StatusCode PUBLIC_API CLBlastStrsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, const size_t m, const size_t n, const float alpha, @@ -1221,6 +1258,12 @@ StatusCode PUBLIC_API CLBlastZtrsm(const Layout layout, const Side side, const T const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, cl_mem b_buffer, const size_t b_offset, const size_t b_ld, cl_command_queue* queue, cl_event* event); +StatusCode PUBLIC_API CLBlastHtrsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + cl_command_queue* queue, cl_event* event); // ================================================================================================= diff --git a/include/internal/database.h b/include/internal/database.h index 34629bf5..f93eaa22 100644 --- a/include/internal/database.h +++ b/include/internal/database.h @@ -71,7 +71,7 @@ class Database { static const DatabaseEntry XdotHalf, XdotSingle, XdotDouble, XdotComplexSingle, XdotComplexDouble; static const DatabaseEntry XgemvHalf, XgemvSingle, XgemvDouble, XgemvComplexSingle, XgemvComplexDouble; static const DatabaseEntry XgerHalf, XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble; - static const DatabaseEntry XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble; + static const DatabaseEntry XgemmHalf, XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble; static const DatabaseEntry CopyHalf, CopySingle, CopyDouble, CopyComplexSingle, CopyComplexDouble; static const DatabaseEntry PadHalf, PadSingle, PadDouble, PadComplexSingle, PadComplexDouble; static const DatabaseEntry TransposeHalf, TransposeSingle, TransposeDouble, TransposeComplexSingle, TransposeComplexDouble; diff --git a/include/internal/database/xgemm.h b/include/internal/database/xgemm.h index 9ca2bff5..647188e9 100644 --- a/include/internal/database/xgemm.h +++ b/include/internal/database/xgemm.h @@ -14,6 +14,18 @@ namespace clblast { // ================================================================================================= +const Database::DatabaseEntry Database::XgemmHalf = { + "Xgemm", Precision::kHalf, { + { // Default + kDeviceTypeAll, "default", { + { "default", { {"KWG",16}, {"KWI",2}, {"MDIMA",8}, {"MDIMC",8}, {"MWG",32}, {"NDIMB",8}, {"NDIMC",8}, {"NWG",64}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} } }, + } + }, + } +}; + +// ================================================================================================= + const Database::DatabaseEntry Database::XgemmSingle = { "Xgemm", Precision::kSingle, { { // AMD GPUs |