summaryrefslogtreecommitdiff
path: root/src/clblast.cpp
diff options
context:
space:
mode:
authorIvan Shapovalov <intelfx@intelfx.name>2016-11-28 03:03:24 +0300
committerIvan Shapovalov <intelfx@intelfx.name>2017-01-24 02:43:00 +0300
commit1b8e81633357004767747893cf3e179ed4f53dcc (patch)
tree516e799d0a94629f28b29dde189907ebca4b465e /src/clblast.cpp
parent6ad11665a15afd53b4d1243b0319820e3340826d (diff)
FillCache: perform compilation for each precision separately
Thus do not prevent filling cache for float if the device does not support e. g. double.
Diffstat (limited to 'src/clblast.cpp')
-rw-r--r--src/clblast.cpp119
1 files changed, 67 insertions, 52 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 6a47316e..e0f8add2 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2170,6 +2170,71 @@ StatusCode ClearCache() {
return StatusCode::kSuccess;
}
+template <typename Real, typename Complex>
+void FillCacheForPrecision(Queue &queue) {
+ try {
+
+ // Runs all the level 1 set-up functions
+ Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr);
+ Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr);
+ Xscal<Real>(queue, nullptr); Xscal<Complex>(queue, nullptr);
+ Xcopy<Real>(queue, nullptr); Xcopy<Complex>(queue, nullptr);
+ Xaxpy<Real>(queue, nullptr); Xaxpy<Complex>(queue, nullptr);
+ Xdot<Real>(queue, nullptr);
+ Xdotu<Complex>(queue, nullptr);
+ Xdotc<Complex>(queue, nullptr);
+ Xnrm2<Real>(queue, nullptr); Xnrm2<Complex>(queue, nullptr);
+ Xasum<Real>(queue, nullptr); Xasum<Complex>(queue, nullptr);
+ Xsum<Real>(queue, nullptr); Xsum<Complex>(queue, nullptr);
+ Xamax<Real>(queue, nullptr); Xamax<Complex>(queue, nullptr);
+ Xmax<Real>(queue, nullptr); Xmax<Complex>(queue, nullptr);
+ Xmin<Real>(queue, nullptr); Xmin<Complex>(queue, nullptr);
+
+ // Runs all the level 2 set-up functions
+ Xgemv<Real>(queue, nullptr); Xgemv<Complex>(queue, nullptr);
+ Xgbmv<Real>(queue, nullptr); Xgbmv<Complex>(queue, nullptr);
+ Xhemv<Complex>(queue, nullptr);
+ Xhbmv<Complex>(queue, nullptr);
+ Xhpmv<Complex>(queue, nullptr);
+ Xsymv<Real>(queue, nullptr);
+ Xsbmv<Real>(queue, nullptr);
+ Xspmv<Real>(queue, nullptr);
+ Xtrmv<Real>(queue, nullptr); Xtrmv<Complex>(queue, nullptr);
+ Xtbmv<Real>(queue, nullptr); Xtbmv<Complex>(queue, nullptr);
+ Xtpmv<Real>(queue, nullptr); Xtpmv<Complex>(queue, nullptr);
+ Xger<Real>(queue, nullptr);
+ Xgeru<Complex>(queue, nullptr);
+ Xgerc<Complex>(queue, nullptr);
+ Xher<Complex,Real>(queue, nullptr);
+ Xhpr<Complex,Real>(queue, nullptr);
+ Xher2<Complex>(queue, nullptr);
+ Xhpr2<Complex>(queue, nullptr);
+ Xsyr<Real>(queue, nullptr);
+ Xspr<Real>(queue, nullptr);
+ Xsyr2<Real>(queue, nullptr);
+ Xspr2<Real>(queue, nullptr);
+
+ // Runs all the level 3 set-up functions
+ Xgemm<Real>(queue, nullptr); Xgemm<Complex>(queue, nullptr);
+ Xsymm<Real>(queue, nullptr); Xsymm<Complex>(queue, nullptr);
+ Xhemm<Complex>(queue, nullptr);
+ Xsyrk<Real>(queue, nullptr); Xsyrk<Complex>(queue, nullptr);
+ Xherk<Complex,Real>(queue, nullptr);
+ Xsyr2k<Real>(queue, nullptr); Xsyr2k<Complex>(queue, nullptr);
+ Xher2k<Complex,Real>(queue, nullptr);
+ Xtrmm<Real>(queue, nullptr); Xtrmm<Complex>(queue, nullptr);
+
+ // Runs all the non-BLAS set-up functions
+ Xomatcopy<Real>(queue, nullptr); Xomatcopy<Complex>(queue, nullptr);
+
+ } catch(const RuntimeErrorCode &e) {
+ if (e.status() != StatusCode::kNoDoublePrecision &&
+ e.status() != StatusCode::kNoHalfPrecision) {
+ throw;
+ }
+ }
+}
+
// Fills the cache with all binaries for a specific device
// TODO: Add half-precision FP16 set-up calls
StatusCode FillCache(const cl_device_id device) {
@@ -2180,58 +2245,8 @@ StatusCode FillCache(const cl_device_id device) {
auto context = Context(device_cpp);
auto queue = Queue(context, device_cpp);
- // Runs all the level 1 set-up functions
- Xswap<float>(queue, nullptr); Xswap<double>(queue, nullptr); Xswap<float2>(queue, nullptr); Xswap<double2>(queue, nullptr);
- Xswap<float>(queue, nullptr); Xswap<double>(queue, nullptr); Xswap<float2>(queue, nullptr); Xswap<double2>(queue, nullptr);
- Xscal<float>(queue, nullptr); Xscal<double>(queue, nullptr); Xscal<float2>(queue, nullptr); Xscal<double2>(queue, nullptr);
- Xcopy<float>(queue, nullptr); Xcopy<double>(queue, nullptr); Xcopy<float2>(queue, nullptr); Xcopy<double2>(queue, nullptr);
- Xaxpy<float>(queue, nullptr); Xaxpy<double>(queue, nullptr); Xaxpy<float2>(queue, nullptr); Xaxpy<double2>(queue, nullptr);
- Xdot<float>(queue, nullptr); Xdot<double>(queue, nullptr);
- Xdotu<float2>(queue, nullptr); Xdotu<double2>(queue, nullptr);
- Xdotc<float2>(queue, nullptr); Xdotc<double2>(queue, nullptr);
- Xnrm2<float>(queue, nullptr); Xnrm2<double>(queue, nullptr); Xnrm2<float2>(queue, nullptr); Xnrm2<double2>(queue, nullptr);
- Xasum<float>(queue, nullptr); Xasum<double>(queue, nullptr); Xasum<float2>(queue, nullptr); Xasum<double2>(queue, nullptr);
- Xsum<float>(queue, nullptr); Xsum<double>(queue, nullptr); Xsum<float2>(queue, nullptr); Xsum<double2>(queue, nullptr);
- Xamax<float>(queue, nullptr); Xamax<double>(queue, nullptr); Xamax<float2>(queue, nullptr); Xamax<double2>(queue, nullptr);
- Xmax<float>(queue, nullptr); Xmax<double>(queue, nullptr); Xmax<float2>(queue, nullptr); Xmax<double2>(queue, nullptr);
- Xmin<float>(queue, nullptr); Xmin<double>(queue, nullptr); Xmin<float2>(queue, nullptr); Xmin<double2>(queue, nullptr);
-
- // Runs all the level 2 set-up functions
- Xgemv<float>(queue, nullptr); Xgemv<double>(queue, nullptr); Xgemv<float2>(queue, nullptr); Xgemv<double2>(queue, nullptr);
- Xgbmv<float>(queue, nullptr); Xgbmv<double>(queue, nullptr); Xgbmv<float2>(queue, nullptr); Xgbmv<double2>(queue, nullptr);
- Xhemv<float2>(queue, nullptr); Xhemv<double2>(queue, nullptr);
- Xhbmv<float2>(queue, nullptr); Xhbmv<double2>(queue, nullptr);
- Xhpmv<float2>(queue, nullptr); Xhpmv<double2>(queue, nullptr);
- Xsymv<float>(queue, nullptr); Xsymv<double>(queue, nullptr);
- Xsbmv<float>(queue, nullptr); Xsbmv<double>(queue, nullptr);
- Xspmv<float>(queue, nullptr); Xspmv<double>(queue, nullptr);
- Xtrmv<float>(queue, nullptr); Xtrmv<double>(queue, nullptr); Xtrmv<float2>(queue, nullptr); Xtrmv<double2>(queue, nullptr);
- Xtbmv<float>(queue, nullptr); Xtbmv<double>(queue, nullptr); Xtbmv<float2>(queue, nullptr); Xtbmv<double2>(queue, nullptr);
- Xtpmv<float>(queue, nullptr); Xtpmv<double>(queue, nullptr); Xtpmv<float2>(queue, nullptr); Xtpmv<double2>(queue, nullptr);
- Xger<float>(queue, nullptr); Xger<double>(queue, nullptr);
- Xgeru<float2>(queue, nullptr); Xgeru<double2>(queue, nullptr);
- Xgerc<float2>(queue, nullptr); Xgerc<double2>(queue, nullptr);
- Xher<float2,float>(queue, nullptr); Xher<double2,double>(queue, nullptr);
- Xhpr<float2,float>(queue, nullptr); Xhpr<double2,double>(queue, nullptr);
- Xher2<float2>(queue, nullptr); Xher2<double2>(queue, nullptr);
- Xhpr2<float2>(queue, nullptr); Xhpr2<double2>(queue, nullptr);
- Xsyr<float>(queue, nullptr); Xsyr<double>(queue, nullptr);
- Xspr<float>(queue, nullptr); Xspr<double>(queue, nullptr);
- Xsyr2<float>(queue, nullptr); Xsyr2<double>(queue, nullptr);
- Xspr2<float>(queue, nullptr); Xspr2<double>(queue, nullptr);
-
- // Runs all the level 3 set-up functions
- Xgemm<float>(queue, nullptr); Xgemm<double>(queue, nullptr); Xgemm<float2>(queue, nullptr); Xgemm<double2>(queue, nullptr);
- Xsymm<float>(queue, nullptr); Xsymm<double>(queue, nullptr); Xsymm<float2>(queue, nullptr); Xsymm<double2>(queue, nullptr);
- Xhemm<float2>(queue, nullptr); Xhemm<double2>(queue, nullptr);
- Xsyrk<float>(queue, nullptr); Xsyrk<double>(queue, nullptr); Xsyrk<float2>(queue, nullptr); Xsyrk<double2>(queue, nullptr);
- Xherk<float2,float>(queue, nullptr); Xherk<double2,double>(queue, nullptr);
- Xsyr2k<float>(queue, nullptr); Xsyr2k<double>(queue, nullptr); Xsyr2k<float2>(queue, nullptr); Xsyr2k<double2>(queue, nullptr);
- Xher2k<float2,float>(queue, nullptr); Xher2k<double2,double>(queue, nullptr);
- Xtrmm<float>(queue, nullptr); Xtrmm<double>(queue, nullptr); Xtrmm<float2>(queue, nullptr); Xtrmm<double2>(queue, nullptr);
-
- // Runs all the level 3 set-up functions
- Xomatcopy<float>(queue, nullptr); Xomatcopy<double>(queue, nullptr); Xomatcopy<float2>(queue, nullptr); Xomatcopy<double2>(queue, nullptr);
+ FillCacheForPrecision<float, float2>(queue);
+ FillCacheForPrecision<double, double2>(queue);
} catch (...) { return DispatchException(); }
return StatusCode::kSuccess;