diff options
-rwxr-xr-x | scripts/generator/generator.py | 2 | ||||
-rw-r--r-- | src/clblast.cpp | 119 |
2 files changed, 68 insertions, 53 deletions
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index 35d902b7..6591cbf7 100755 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -42,7 +42,7 @@ FILES = [ "/src/clblast_netlib_c.cpp", ] HEADER_LINES = [117, 73, 118, 22, 29, 41, 65, 32] -FOOTER_LINES = [17, 80, 19, 18, 6, 6, 9, 2] +FOOTER_LINES = [17, 95, 19, 18, 6, 6, 9, 2] # Different possibilities for requirements ald_m = "The value of `a_ld` must be at least `m`." diff --git a/src/clblast.cpp b/src/clblast.cpp index 6a47316e..e0f8add2 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2170,6 +2170,71 @@ StatusCode ClearCache() { return StatusCode::kSuccess; } +template <typename Real, typename Complex> +void FillCacheForPrecision(Queue &queue) { + try { + + // Runs all the level 1 set-up functions + Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr); + Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr); + Xscal<Real>(queue, nullptr); Xscal<Complex>(queue, nullptr); + Xcopy<Real>(queue, nullptr); Xcopy<Complex>(queue, nullptr); + Xaxpy<Real>(queue, nullptr); Xaxpy<Complex>(queue, nullptr); + Xdot<Real>(queue, nullptr); + Xdotu<Complex>(queue, nullptr); + Xdotc<Complex>(queue, nullptr); + Xnrm2<Real>(queue, nullptr); Xnrm2<Complex>(queue, nullptr); + Xasum<Real>(queue, nullptr); Xasum<Complex>(queue, nullptr); + Xsum<Real>(queue, nullptr); Xsum<Complex>(queue, nullptr); + Xamax<Real>(queue, nullptr); Xamax<Complex>(queue, nullptr); + Xmax<Real>(queue, nullptr); Xmax<Complex>(queue, nullptr); + Xmin<Real>(queue, nullptr); Xmin<Complex>(queue, nullptr); + + // Runs all the level 2 set-up functions + Xgemv<Real>(queue, nullptr); Xgemv<Complex>(queue, nullptr); + Xgbmv<Real>(queue, nullptr); Xgbmv<Complex>(queue, nullptr); + Xhemv<Complex>(queue, nullptr); + Xhbmv<Complex>(queue, nullptr); + Xhpmv<Complex>(queue, nullptr); + Xsymv<Real>(queue, nullptr); + Xsbmv<Real>(queue, nullptr); + Xspmv<Real>(queue, nullptr); + Xtrmv<Real>(queue, nullptr); Xtrmv<Complex>(queue, nullptr); + Xtbmv<Real>(queue, nullptr); Xtbmv<Complex>(queue, nullptr); + Xtpmv<Real>(queue, nullptr); Xtpmv<Complex>(queue, nullptr); + Xger<Real>(queue, nullptr); + Xgeru<Complex>(queue, nullptr); + Xgerc<Complex>(queue, nullptr); + Xher<Complex,Real>(queue, nullptr); + Xhpr<Complex,Real>(queue, nullptr); + Xher2<Complex>(queue, nullptr); + Xhpr2<Complex>(queue, nullptr); + Xsyr<Real>(queue, nullptr); + Xspr<Real>(queue, nullptr); + Xsyr2<Real>(queue, nullptr); + Xspr2<Real>(queue, nullptr); + + // Runs all the level 3 set-up functions + Xgemm<Real>(queue, nullptr); Xgemm<Complex>(queue, nullptr); + Xsymm<Real>(queue, nullptr); Xsymm<Complex>(queue, nullptr); + Xhemm<Complex>(queue, nullptr); + Xsyrk<Real>(queue, nullptr); Xsyrk<Complex>(queue, nullptr); + Xherk<Complex,Real>(queue, nullptr); + Xsyr2k<Real>(queue, nullptr); Xsyr2k<Complex>(queue, nullptr); + Xher2k<Complex,Real>(queue, nullptr); + Xtrmm<Real>(queue, nullptr); Xtrmm<Complex>(queue, nullptr); + + // Runs all the non-BLAS set-up functions + Xomatcopy<Real>(queue, nullptr); Xomatcopy<Complex>(queue, nullptr); + + } catch(const RuntimeErrorCode &e) { + if (e.status() != StatusCode::kNoDoublePrecision && + e.status() != StatusCode::kNoHalfPrecision) { + throw; + } + } +} + // Fills the cache with all binaries for a specific device // TODO: Add half-precision FP16 set-up calls StatusCode FillCache(const cl_device_id device) { @@ -2180,58 +2245,8 @@ StatusCode FillCache(const cl_device_id device) { auto context = Context(device_cpp); auto queue = Queue(context, device_cpp); - // Runs all the level 1 set-up functions - Xswap<float>(queue, nullptr); Xswap<double>(queue, nullptr); Xswap<float2>(queue, nullptr); Xswap<double2>(queue, nullptr); - Xswap<float>(queue, nullptr); Xswap<double>(queue, nullptr); Xswap<float2>(queue, nullptr); Xswap<double2>(queue, nullptr); - Xscal<float>(queue, nullptr); Xscal<double>(queue, nullptr); Xscal<float2>(queue, nullptr); Xscal<double2>(queue, nullptr); - Xcopy<float>(queue, nullptr); Xcopy<double>(queue, nullptr); Xcopy<float2>(queue, nullptr); Xcopy<double2>(queue, nullptr); - Xaxpy<float>(queue, nullptr); Xaxpy<double>(queue, nullptr); Xaxpy<float2>(queue, nullptr); Xaxpy<double2>(queue, nullptr); - Xdot<float>(queue, nullptr); Xdot<double>(queue, nullptr); - Xdotu<float2>(queue, nullptr); Xdotu<double2>(queue, nullptr); - Xdotc<float2>(queue, nullptr); Xdotc<double2>(queue, nullptr); - Xnrm2<float>(queue, nullptr); Xnrm2<double>(queue, nullptr); Xnrm2<float2>(queue, nullptr); Xnrm2<double2>(queue, nullptr); - Xasum<float>(queue, nullptr); Xasum<double>(queue, nullptr); Xasum<float2>(queue, nullptr); Xasum<double2>(queue, nullptr); - Xsum<float>(queue, nullptr); Xsum<double>(queue, nullptr); Xsum<float2>(queue, nullptr); Xsum<double2>(queue, nullptr); - Xamax<float>(queue, nullptr); Xamax<double>(queue, nullptr); Xamax<float2>(queue, nullptr); Xamax<double2>(queue, nullptr); - Xmax<float>(queue, nullptr); Xmax<double>(queue, nullptr); Xmax<float2>(queue, nullptr); Xmax<double2>(queue, nullptr); - Xmin<float>(queue, nullptr); Xmin<double>(queue, nullptr); Xmin<float2>(queue, nullptr); Xmin<double2>(queue, nullptr); - - // Runs all the level 2 set-up functions - Xgemv<float>(queue, nullptr); Xgemv<double>(queue, nullptr); Xgemv<float2>(queue, nullptr); Xgemv<double2>(queue, nullptr); - Xgbmv<float>(queue, nullptr); Xgbmv<double>(queue, nullptr); Xgbmv<float2>(queue, nullptr); Xgbmv<double2>(queue, nullptr); - Xhemv<float2>(queue, nullptr); Xhemv<double2>(queue, nullptr); - Xhbmv<float2>(queue, nullptr); Xhbmv<double2>(queue, nullptr); - Xhpmv<float2>(queue, nullptr); Xhpmv<double2>(queue, nullptr); - Xsymv<float>(queue, nullptr); Xsymv<double>(queue, nullptr); - Xsbmv<float>(queue, nullptr); Xsbmv<double>(queue, nullptr); - Xspmv<float>(queue, nullptr); Xspmv<double>(queue, nullptr); - Xtrmv<float>(queue, nullptr); Xtrmv<double>(queue, nullptr); Xtrmv<float2>(queue, nullptr); Xtrmv<double2>(queue, nullptr); - Xtbmv<float>(queue, nullptr); Xtbmv<double>(queue, nullptr); Xtbmv<float2>(queue, nullptr); Xtbmv<double2>(queue, nullptr); - Xtpmv<float>(queue, nullptr); Xtpmv<double>(queue, nullptr); Xtpmv<float2>(queue, nullptr); Xtpmv<double2>(queue, nullptr); - Xger<float>(queue, nullptr); Xger<double>(queue, nullptr); - Xgeru<float2>(queue, nullptr); Xgeru<double2>(queue, nullptr); - Xgerc<float2>(queue, nullptr); Xgerc<double2>(queue, nullptr); - Xher<float2,float>(queue, nullptr); Xher<double2,double>(queue, nullptr); - Xhpr<float2,float>(queue, nullptr); Xhpr<double2,double>(queue, nullptr); - Xher2<float2>(queue, nullptr); Xher2<double2>(queue, nullptr); - Xhpr2<float2>(queue, nullptr); Xhpr2<double2>(queue, nullptr); - Xsyr<float>(queue, nullptr); Xsyr<double>(queue, nullptr); - Xspr<float>(queue, nullptr); Xspr<double>(queue, nullptr); - Xsyr2<float>(queue, nullptr); Xsyr2<double>(queue, nullptr); - Xspr2<float>(queue, nullptr); Xspr2<double>(queue, nullptr); - - // Runs all the level 3 set-up functions - Xgemm<float>(queue, nullptr); Xgemm<double>(queue, nullptr); Xgemm<float2>(queue, nullptr); Xgemm<double2>(queue, nullptr); - Xsymm<float>(queue, nullptr); Xsymm<double>(queue, nullptr); Xsymm<float2>(queue, nullptr); Xsymm<double2>(queue, nullptr); - Xhemm<float2>(queue, nullptr); Xhemm<double2>(queue, nullptr); - Xsyrk<float>(queue, nullptr); Xsyrk<double>(queue, nullptr); Xsyrk<float2>(queue, nullptr); Xsyrk<double2>(queue, nullptr); - Xherk<float2,float>(queue, nullptr); Xherk<double2,double>(queue, nullptr); - Xsyr2k<float>(queue, nullptr); Xsyr2k<double>(queue, nullptr); Xsyr2k<float2>(queue, nullptr); Xsyr2k<double2>(queue, nullptr); - Xher2k<float2,float>(queue, nullptr); Xher2k<double2,double>(queue, nullptr); - Xtrmm<float>(queue, nullptr); Xtrmm<double>(queue, nullptr); Xtrmm<float2>(queue, nullptr); Xtrmm<double2>(queue, nullptr); - - // Runs all the level 3 set-up functions - Xomatcopy<float>(queue, nullptr); Xomatcopy<double>(queue, nullptr); Xomatcopy<float2>(queue, nullptr); Xomatcopy<double2>(queue, nullptr); + FillCacheForPrecision<float, float2>(queue); + FillCacheForPrecision<double, double2>(queue); } catch (...) { return DispatchException(); } return StatusCode::kSuccess; |