summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/clblast.h4
-rw-r--r--include/clblast_c.h4
-rw-r--r--scripts/generator/generator.py2
-rw-r--r--src/clblast.cc62
-rw-r--r--src/clblast_c.cc5
5 files changed, 76 insertions, 1 deletions
diff --git a/include/clblast.h b/include/clblast.h
index e473adbe..075ca93e 100644
--- a/include/clblast.h
+++ b/include/clblast.h
@@ -558,6 +558,10 @@ StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, c
// for the same device. This cache can be cleared to free up system memory or in case of debugging.
StatusCode ClearCache();
+// The cache can also be pre-initialized for a specific device with all possible CLBLast kernels.
+// Further CLBlast routine calls will then run at maximum speed.
+StatusCode FillCache(const cl_device_id device);
+
// =================================================================================================
} // namespace clblast
diff --git a/include/clblast_c.h b/include/clblast_c.h
index 45e50cff..dd9b0f67 100644
--- a/include/clblast_c.h
+++ b/include/clblast_c.h
@@ -1076,6 +1076,10 @@ StatusCode PUBLIC_API CLBlastZtrsm(const Layout layout, const Side side, const T
// for the same device. This cache can be cleared to free up system memory or in case of debugging.
StatusCode PUBLIC_API CLBlastClearCache();
+// The cache can also be pre-initialized for a specific device with all possible CLBLast kernels.
+// Further CLBlast routine calls will then run at maximum speed.
+StatusCode PUBLIC_API CLBlastFillCache(const cl_device_id device);
+
// =================================================================================================
#ifdef __cplusplus
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 04f3c30e..a9419f13 100644
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -299,7 +299,7 @@ files = [
path_clblast+"/test/wrapper_cblas.h",
]
header_lines = [84, 70, 93, 22, 29, 38]
-footer_lines = [13, 8, 15, 9, 6, 6]
+footer_lines = [17, 70, 19, 14, 6, 6]
# Checks whether the command-line arguments are valid; exists otherwise
for f in files:
diff --git a/src/clblast.cc b/src/clblast.cc
index fe79d7c1..a5bb6b67 100644
--- a/src/clblast.cc
+++ b/src/clblast.cc
@@ -1857,5 +1857,67 @@ template StatusCode PUBLIC_API Trsm<double2>(const Layout, const Side, const Tri
// Clears the cache of stored binaries
StatusCode ClearCache() { return cache::ClearCache(); }
+// Fills the cache with all binaries for a specific device
+StatusCode FillCache(const cl_device_id device) {
+ try {
+
+ // Creates a sample context and queue to match the normal routine calling conventions
+ auto device_cpp = Device(device);
+ auto context = Context(device_cpp);
+ auto queue = Queue(context, device_cpp);
+
+ // Runs all the level 1 set-up functions
+ Xswap<float>(queue, nullptr).SetUp(); Xswap<double>(queue, nullptr).SetUp(); Xswap<float2>(queue, nullptr).SetUp(); Xswap<double2>(queue, nullptr).SetUp();
+ Xswap<float>(queue, nullptr).SetUp(); Xswap<double>(queue, nullptr).SetUp(); Xswap<float2>(queue, nullptr).SetUp(); Xswap<double2>(queue, nullptr).SetUp();
+ Xscal<float>(queue, nullptr).SetUp(); Xscal<double>(queue, nullptr).SetUp(); Xscal<float2>(queue, nullptr).SetUp(); Xscal<double2>(queue, nullptr).SetUp();
+ Xcopy<float>(queue, nullptr).SetUp(); Xcopy<double>(queue, nullptr).SetUp(); Xcopy<float2>(queue, nullptr).SetUp(); Xcopy<double2>(queue, nullptr).SetUp();
+ Xaxpy<float>(queue, nullptr).SetUp(); Xaxpy<double>(queue, nullptr).SetUp(); Xaxpy<float2>(queue, nullptr).SetUp(); Xaxpy<double2>(queue, nullptr).SetUp();
+ Xdot<float>(queue, nullptr).SetUp(); Xdot<double>(queue, nullptr).SetUp();
+ Xdotu<float2>(queue, nullptr).SetUp(); Xdotu<double2>(queue, nullptr).SetUp();
+ Xdotc<float2>(queue, nullptr).SetUp(); Xdotc<double2>(queue, nullptr).SetUp();
+ Xnrm2<float>(queue, nullptr).SetUp(); Xnrm2<double>(queue, nullptr).SetUp(); Xnrm2<float2>(queue, nullptr).SetUp(); Xnrm2<double2>(queue, nullptr).SetUp();
+ Xasum<float>(queue, nullptr).SetUp(); Xasum<double>(queue, nullptr).SetUp(); Xasum<float2>(queue, nullptr).SetUp(); Xasum<double2>(queue, nullptr).SetUp();
+ Xsum<float>(queue, nullptr).SetUp(); Xsum<double>(queue, nullptr).SetUp(); Xsum<float2>(queue, nullptr).SetUp(); Xsum<double2>(queue, nullptr).SetUp();
+ Xamax<float>(queue, nullptr).SetUp(); Xamax<double>(queue, nullptr).SetUp(); Xamax<float2>(queue, nullptr).SetUp(); Xamax<double2>(queue, nullptr).SetUp();
+ Xmax<float>(queue, nullptr).SetUp(); Xmax<double>(queue, nullptr).SetUp(); Xmax<float2>(queue, nullptr).SetUp(); Xmax<double2>(queue, nullptr).SetUp();
+
+ // Runs all the level 2 set-up functions
+ Xgemv<float>(queue, nullptr).SetUp(); Xgemv<double>(queue, nullptr).SetUp(); Xgemv<float2>(queue, nullptr).SetUp(); Xgemv<double2>(queue, nullptr).SetUp();
+ Xgbmv<float>(queue, nullptr).SetUp(); Xgbmv<double>(queue, nullptr).SetUp(); Xgbmv<float2>(queue, nullptr).SetUp(); Xgbmv<double2>(queue, nullptr).SetUp();
+ Xhemv<float2>(queue, nullptr).SetUp(); Xhemv<double2>(queue, nullptr).SetUp();
+ Xhbmv<float2>(queue, nullptr).SetUp(); Xhbmv<double2>(queue, nullptr).SetUp();
+ Xhpmv<float2>(queue, nullptr).SetUp(); Xhpmv<double2>(queue, nullptr).SetUp();
+ Xsymv<float>(queue, nullptr).SetUp(); Xsymv<double>(queue, nullptr).SetUp();
+ Xsbmv<float>(queue, nullptr).SetUp(); Xsbmv<double>(queue, nullptr).SetUp();
+ Xspmv<float>(queue, nullptr).SetUp(); Xspmv<double>(queue, nullptr).SetUp();
+ Xtrmv<float>(queue, nullptr).SetUp(); Xtrmv<double>(queue, nullptr).SetUp(); Xtrmv<float2>(queue, nullptr).SetUp(); Xtrmv<double2>(queue, nullptr).SetUp();
+ Xtbmv<float>(queue, nullptr).SetUp(); Xtbmv<double>(queue, nullptr).SetUp(); Xtbmv<float2>(queue, nullptr).SetUp(); Xtbmv<double2>(queue, nullptr).SetUp();
+ Xtpmv<float>(queue, nullptr).SetUp(); Xtpmv<double>(queue, nullptr).SetUp(); Xtpmv<float2>(queue, nullptr).SetUp(); Xtpmv<double2>(queue, nullptr).SetUp();
+ Xger<float>(queue, nullptr).SetUp(); Xger<double>(queue, nullptr).SetUp();
+ Xgeru<float2>(queue, nullptr).SetUp(); Xgeru<double2>(queue, nullptr).SetUp();
+ Xgerc<float2>(queue, nullptr).SetUp(); Xgerc<double2>(queue, nullptr).SetUp();
+ Xher<float2,float>(queue, nullptr).SetUp(); Xher<double2,double>(queue, nullptr).SetUp();
+ Xhpr<float2,float>(queue, nullptr).SetUp(); Xhpr<double2,double>(queue, nullptr).SetUp();
+ Xher2<float2>(queue, nullptr).SetUp(); Xher2<double2>(queue, nullptr).SetUp();
+ Xhpr2<float2>(queue, nullptr).SetUp(); Xhpr2<double2>(queue, nullptr).SetUp();
+ Xsyr<float>(queue, nullptr).SetUp(); Xsyr<double>(queue, nullptr).SetUp();
+ Xspr<float>(queue, nullptr).SetUp(); Xspr<double>(queue, nullptr).SetUp();
+ Xsyr2<float>(queue, nullptr).SetUp(); Xsyr2<double>(queue, nullptr).SetUp();
+ Xspr2<float>(queue, nullptr).SetUp(); Xspr2<double>(queue, nullptr).SetUp();
+
+ // Runs all the level 1 set-up functions
+ Xgemm<float>(queue, nullptr).SetUp(); Xgemm<double>(queue, nullptr).SetUp(); Xgemm<float2>(queue, nullptr).SetUp(); Xgemm<double2>(queue, nullptr).SetUp();
+ Xsymm<float>(queue, nullptr).SetUp(); Xsymm<double>(queue, nullptr).SetUp(); Xsymm<float2>(queue, nullptr).SetUp(); Xsymm<double2>(queue, nullptr).SetUp();
+ Xhemm<float2>(queue, nullptr).SetUp(); Xhemm<double2>(queue, nullptr).SetUp();
+ Xsyrk<float>(queue, nullptr).SetUp(); Xsyrk<double>(queue, nullptr).SetUp(); Xsyrk<float2>(queue, nullptr).SetUp(); Xsyrk<double2>(queue, nullptr).SetUp();
+ Xherk<float2,float>(queue, nullptr).SetUp(); Xherk<double2,double>(queue, nullptr).SetUp();
+ Xsyr2k<float>(queue, nullptr).SetUp(); Xsyr2k<double>(queue, nullptr).SetUp(); Xsyr2k<float2>(queue, nullptr).SetUp(); Xsyr2k<double2>(queue, nullptr).SetUp();
+ Xher2k<float2,float>(queue, nullptr).SetUp(); Xher2k<double2,double>(queue, nullptr).SetUp();
+ Xtrmm<float>(queue, nullptr).SetUp(); Xtrmm<double>(queue, nullptr).SetUp(); Xtrmm<float2>(queue, nullptr).SetUp(); Xtrmm<double2>(queue, nullptr).SetUp();
+
+ } catch (...) { return StatusCode::kBuildProgramFailure; }
+ return StatusCode::kSuccess;
+}
+
// =================================================================================================
} // namespace clblast
diff --git a/src/clblast_c.cc b/src/clblast_c.cc
index 172bce64..47ab1798 100644
--- a/src/clblast_c.cc
+++ b/src/clblast_c.cc
@@ -2348,4 +2348,9 @@ StatusCode CLBlastClearCache() {
return static_cast<StatusCode>(clblast::ClearCache());
}
+// Fills the cache with binaries for a specific device
+StatusCode CLBlastFillCache(const cl_device_id device) {
+ return static_cast<StatusCode>(clblast::FillCache(device));
+}
+
// =================================================================================================