diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2016-05-25 13:29:53 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2016-05-25 13:29:53 +0200 |
commit | 9f8745507020961b1c287febc3a5634b46ccb0e9 (patch) | |
tree | ff776b8b8fcf56529eaeada54a6c05c4bdfff264 /src | |
parent | ac1575056e0f3d7406cc7bcbbdbe71b08feb58ce (diff) |
Added level-3 half-precision routines HGEMM/HSYMM/HSYRK/HSYR2K/HTRMM
Diffstat (limited to 'src')
-rw-r--r-- | src/clblast.cc | 55 | ||||
-rw-r--r-- | src/clblast_c.cc | 114 | ||||
-rw-r--r-- | src/database.cc | 2 | ||||
-rw-r--r-- | src/routines/level3/xsymm.cc | 1 | ||||
-rw-r--r-- | src/routines/level3/xsyr2k.cc | 2 | ||||
-rw-r--r-- | src/routines/level3/xsyrk.cc | 2 | ||||
-rw-r--r-- | src/routines/level3/xtrmm.cc | 1 |
7 files changed, 170 insertions, 7 deletions
diff --git a/src/clblast.cc b/src/clblast.cc index 449c7321..07322327 100644 --- a/src/clblast.cc +++ b/src/clblast.cc @@ -1613,7 +1613,7 @@ template StatusCode PUBLIC_API Spr2<half>(const Layout, const Triangle, // BLAS level-3 (matrix-matrix) routines // ================================================================================================= -// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM +// General matrix-matrix multiplication: SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM template <typename T> StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, const size_t m, const size_t n, const size_t k, @@ -1667,8 +1667,16 @@ template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, cons const double2, cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const half, + const cl_mem, const size_t, const size_t, + const cl_mem, const size_t, const size_t, + const half, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); -// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM +// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM template <typename T> StatusCode Symm(const Layout layout, const Side side, const Triangle triangle, const size_t m, const size_t n, @@ -1722,6 +1730,14 @@ template StatusCode PUBLIC_API Symm<double2>(const Layout, const Side, const Tri const double2, cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Symm<half>(const Layout, const Side, const Triangle, + const size_t, const size_t, + const half, + const cl_mem, const size_t, const size_t, + const cl_mem, const size_t, const size_t, + const half, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); // Hermitian matrix-matrix multiplication: CHEMM/ZHEMM template <typename T> @@ -1762,7 +1778,7 @@ template StatusCode PUBLIC_API Hemm<double2>(const Layout, const Side, const Tri cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); -// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK +// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK/HSYRK template <typename T> StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose a_transpose, const size_t n, const size_t k, @@ -1810,6 +1826,13 @@ template StatusCode PUBLIC_API Syrk<double2>(const Layout, const Triangle, const const double2, cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Syrk<half>(const Layout, const Triangle, const Transpose, + const size_t, const size_t, + const half, + const cl_mem, const size_t, const size_t, + const half, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); // Rank-K update of a hermitian matrix: CHERK/ZHERK template <typename T> @@ -1846,7 +1869,7 @@ template StatusCode PUBLIC_API Herk<double>(const Layout, const Triangle, const cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); -// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K +// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K/HSYR2K template <typename T> StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, const size_t n, const size_t k, @@ -1900,6 +1923,14 @@ template StatusCode PUBLIC_API Syr2k<double2>(const Layout, const Triangle, cons const double2, cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Syr2k<half>(const Layout, const Triangle, const Transpose, + const size_t, const size_t, + const half, + const cl_mem, const size_t, const size_t, + const cl_mem, const size_t, const size_t, + const half, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); // Rank-2K update of a hermitian matrix: CHER2K/ZHER2K template <typename T, typename U> @@ -1940,7 +1971,7 @@ template StatusCode PUBLIC_API Her2k<double2,double>(const Layout, const Triangl cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); -// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM +// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM/HTRMM template <typename T> StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, const size_t m, const size_t n, @@ -1982,8 +2013,14 @@ template StatusCode PUBLIC_API Trmm<double2>(const Layout, const Side, const Tri const cl_mem, const size_t, const size_t, cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Trmm<half>(const Layout, const Side, const Triangle, const Transpose, const Diagonal, + const size_t, const size_t, + const half, + const cl_mem, const size_t, const size_t, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); -// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM +// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM template <typename T> StatusCode Trsm(const Layout, const Side, const Triangle, const Transpose, const Diagonal, const size_t, const size_t, @@ -2017,6 +2054,12 @@ template StatusCode PUBLIC_API Trsm<double2>(const Layout, const Side, const Tri const cl_mem, const size_t, const size_t, cl_mem, const size_t, const size_t, cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Trsm<half>(const Layout, const Side, const Triangle, const Transpose, const Diagonal, + const size_t, const size_t, + const half, + const cl_mem, const size_t, const size_t, + cl_mem, const size_t, const size_t, + cl_command_queue*, cl_event*); // ================================================================================================= diff --git a/src/clblast_c.cc b/src/clblast_c.cc index c368a03c..2aac907a 100644 --- a/src/clblast_c.cc +++ b/src/clblast_c.cc @@ -2208,6 +2208,26 @@ StatusCode CLBlastZgemm(const Layout layout, const Transpose a_transpose, const queue, event); return static_cast<StatusCode>(status); } +StatusCode CLBlastHgemm(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event) { + auto status = clblast::Gemm(static_cast<clblast::Layout>(layout), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Transpose>(b_transpose), + m, n, k, + alpha, + a_buffer, a_offset, a_ld, + b_buffer, b_offset, b_ld, + beta, + c_buffer, c_offset, c_ld, + queue, event); + return static_cast<StatusCode>(status); +} // SYMM StatusCode CLBlastSsymm(const Layout layout, const Side side, const Triangle triangle, @@ -2290,6 +2310,26 @@ StatusCode CLBlastZsymm(const Layout layout, const Side side, const Triangle tri queue, event); return static_cast<StatusCode>(status); } +StatusCode CLBlastHsymm(const Layout layout, const Side side, const Triangle triangle, + const size_t m, const size_t n, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event) { + auto status = clblast::Symm(static_cast<clblast::Layout>(layout), + static_cast<clblast::Side>(side), + static_cast<clblast::Triangle>(triangle), + m, n, + alpha, + a_buffer, a_offset, a_ld, + b_buffer, b_offset, b_ld, + beta, + c_buffer, c_offset, c_ld, + queue, event); + return static_cast<StatusCode>(status); +} // HEMM StatusCode CLBlastChemm(const Layout layout, const Side side, const Triangle triangle, @@ -2406,6 +2446,24 @@ StatusCode CLBlastZsyrk(const Layout layout, const Triangle triangle, const Tran queue, event); return static_cast<StatusCode>(status); } +StatusCode CLBlastHsyrk(const Layout layout, const Triangle triangle, const Transpose a_transpose, + const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event) { + auto status = clblast::Syrk(static_cast<clblast::Layout>(layout), + static_cast<clblast::Triangle>(triangle), + static_cast<clblast::Transpose>(a_transpose), + n, k, + alpha, + a_buffer, a_offset, a_ld, + beta, + c_buffer, c_offset, c_ld, + queue, event); + return static_cast<StatusCode>(status); +} // HERK StatusCode CLBlastCherk(const Layout layout, const Triangle triangle, const Transpose a_transpose, @@ -2526,6 +2584,26 @@ StatusCode CLBlastZsyr2k(const Layout layout, const Triangle triangle, const Tra queue, event); return static_cast<StatusCode>(status); } +StatusCode CLBlastHsyr2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, + const size_t n, const size_t k, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + const cl_half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, cl_event* event) { + auto status = clblast::Syr2k(static_cast<clblast::Layout>(layout), + static_cast<clblast::Triangle>(triangle), + static_cast<clblast::Transpose>(ab_transpose), + n, k, + alpha, + a_buffer, a_offset, a_ld, + b_buffer, b_offset, b_ld, + beta, + c_buffer, c_offset, c_ld, + queue, event); + return static_cast<StatusCode>(status); +} // HER2K StatusCode CLBlastCher2k(const Layout layout, const Triangle triangle, const Transpose ab_transpose, @@ -2642,6 +2720,24 @@ StatusCode CLBlastZtrmm(const Layout layout, const Side side, const Triangle tri queue, event); return static_cast<StatusCode>(status); } +StatusCode CLBlastHtrmm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + cl_command_queue* queue, cl_event* event) { + auto status = clblast::Trmm(static_cast<clblast::Layout>(layout), + static_cast<clblast::Side>(side), + static_cast<clblast::Triangle>(triangle), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Diagonal>(diagonal), + m, n, + alpha, + a_buffer, a_offset, a_ld, + b_buffer, b_offset, b_ld, + queue, event); + return static_cast<StatusCode>(status); +} // TRSM StatusCode CLBlastStrsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, @@ -2716,6 +2812,24 @@ StatusCode CLBlastZtrsm(const Layout layout, const Side side, const Triangle tri queue, event); return static_cast<StatusCode>(status); } +StatusCode CLBlastHtrsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal, + const size_t m, const size_t n, + const cl_half alpha, + const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + cl_command_queue* queue, cl_event* event) { + auto status = clblast::Trsm(static_cast<clblast::Layout>(layout), + static_cast<clblast::Side>(side), + static_cast<clblast::Triangle>(triangle), + static_cast<clblast::Transpose>(a_transpose), + static_cast<clblast::Diagonal>(diagonal), + m, n, + alpha, + a_buffer, a_offset, a_ld, + b_buffer, b_offset, b_ld, + queue, event); + return static_cast<StatusCode>(status); +} // ================================================================================================= diff --git a/src/database.cc b/src/database.cc index dc72dbdd..e20ae340 100644 --- a/src/database.cc +++ b/src/database.cc @@ -33,7 +33,7 @@ const std::vector<Database::DatabaseEntry> Database::database = { XdotHalf, XdotSingle, XdotDouble, XdotComplexSingle, XdotComplexDouble, XgemvHalf, XgemvSingle, XgemvDouble, XgemvComplexSingle, XgemvComplexDouble, XgerHalf, XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble, - XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble, + XgemmHalf, XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble, CopyHalf, CopySingle, CopyDouble, CopyComplexSingle, CopyComplexDouble, PadHalf, PadSingle, PadDouble, PadComplexSingle, PadComplexDouble, TransposeHalf, TransposeSingle, TransposeDouble, TransposeComplexSingle, TransposeComplexDouble, diff --git a/src/routines/level3/xsymm.cc b/src/routines/level3/xsymm.cc index a39026f1..d88d4653 100644 --- a/src/routines/level3/xsymm.cc +++ b/src/routines/level3/xsymm.cc @@ -127,6 +127,7 @@ StatusCode Xsymm<T>::DoSymm(const Layout layout, const Side side, const Triangle // ================================================================================================= // Compiles the templated class +template class Xsymm<half>; template class Xsymm<float>; template class Xsymm<double>; template class Xsymm<float2>; diff --git a/src/routines/level3/xsyr2k.cc b/src/routines/level3/xsyr2k.cc index c52e1353..4f86bac5 100644 --- a/src/routines/level3/xsyr2k.cc +++ b/src/routines/level3/xsyr2k.cc @@ -20,6 +20,7 @@ namespace clblast { // ================================================================================================= // Specific implementations to get the memory-type based on a template argument +template <> const Precision Xsyr2k<half>::precision_ = Precision::kHalf; template <> const Precision Xsyr2k<float>::precision_ = Precision::kSingle; template <> const Precision Xsyr2k<double>::precision_ = Precision::kDouble; template <> const Precision Xsyr2k<float2>::precision_ = Precision::kComplexSingle; @@ -203,6 +204,7 @@ StatusCode Xsyr2k<T>::DoSyr2k(const Layout layout, const Triangle triangle, cons // ================================================================================================= // Compiles the templated class +template class Xsyr2k<half>; template class Xsyr2k<float>; template class Xsyr2k<double>; template class Xsyr2k<float2>; diff --git a/src/routines/level3/xsyrk.cc b/src/routines/level3/xsyrk.cc index cfcd4e12..52cb58c0 100644 --- a/src/routines/level3/xsyrk.cc +++ b/src/routines/level3/xsyrk.cc @@ -20,6 +20,7 @@ namespace clblast { // ================================================================================================= // Specific implementations to get the memory-type based on a template argument +template <> const Precision Xsyrk<half>::precision_ = Precision::kHalf; template <> const Precision Xsyrk<float>::precision_ = Precision::kSingle; template <> const Precision Xsyrk<double>::precision_ = Precision::kDouble; template <> const Precision Xsyrk<float2>::precision_ = Precision::kComplexSingle; @@ -175,6 +176,7 @@ StatusCode Xsyrk<T>::DoSyrk(const Layout layout, const Triangle triangle, const // ================================================================================================= // Compiles the templated class +template class Xsyrk<half>; template class Xsyrk<float>; template class Xsyrk<double>; template class Xsyrk<float2>; diff --git a/src/routines/level3/xtrmm.cc b/src/routines/level3/xtrmm.cc index 9e3b27b4..18cbb1c0 100644 --- a/src/routines/level3/xtrmm.cc +++ b/src/routines/level3/xtrmm.cc @@ -130,6 +130,7 @@ StatusCode Xtrmm<T>::DoTrmm(const Layout layout, const Side side, const Triangle // ================================================================================================= // Compiles the templated class +template class Xtrmm<half>; template class Xtrmm<float>; template class Xtrmm<double>; template class Xtrmm<float2>; |