summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/clblast.h11
-rwxr-xr-xscripts/generator/generator.py2
-rw-r--r--src/clblast.cpp53
-rw-r--r--src/routine.cpp20
-rw-r--r--src/routine.hpp20
-rw-r--r--src/routines/level3/xgemm.hpp34
6 files changed, 120 insertions, 20 deletions
diff --git a/include/clblast.h b/include/clblast.h
index e073b211..3318768a 100644
--- a/include/clblast.h
+++ b/include/clblast.h
@@ -647,6 +647,17 @@ StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const T
// =================================================================================================
+// Retrieves the required size of the temporary buffer for the GEMM kernel (optional)
+template <typename T>
+StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ cl_command_queue* queue, size_t& temp_buffer_size);
+
+// =================================================================================================
+
// CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on
// for the same device. This cache can be cleared to free up system memory or in case of debugging.
StatusCode PUBLIC_API ClearCache();
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 520e3fc8..6c671deb 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -47,7 +47,7 @@ FILES = [
"/src/clblast_cuda.cpp",
]
HEADER_LINES = [122, 21, 126, 24, 29, 41, 29, 65, 32, 94, 21]
-FOOTER_LINES = [25, 3, 27, 38, 6, 6, 6, 9, 2, 25, 3]
+FOOTER_LINES = [36, 56, 27, 38, 6, 6, 6, 9, 2, 25, 3]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 63
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 7d2c2cef..e38a75ca 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2333,4 +2333,57 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose,
cl_command_queue*, cl_event*);
// =================================================================================================
+
+// Retrieves the required size of the temporary buffer for the GEMM kernel (optional)
+template <typename T>
+StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ RawCommandQueue* queue, size_t& temp_buffer_size) {
+ try {
+
+ // Retrieves the tuning database
+ const auto queue_cpp = Queue(*queue);
+ const auto device = queue_cpp.GetDevice();
+ const auto kernel_names = std::vector<std::string>{"Xgemm", "GemmRoutine"};
+ Databases db(kernel_names);
+ Routine::InitDatabase(device, kernel_names, PrecisionValue<T>(), {}, db);
+
+ // Computes the buffer size
+ if (Xgemm<T>::UseDirectKernel(m, n, k, db["XGEMM_MIN_INDIRECT_SIZE"])) {
+ temp_buffer_size = 0;
+ }
+ else {
+ temp_buffer_size = Xgemm<T>::GetTempSize(layout, a_transpose, b_transpose, m, n, k,
+ a_offset, a_ld, b_offset, b_ld, c_offset, c_ld,
+ db["MWG"], db["NWG"], db["KWG"]);
+ }
+ temp_buffer_size *= sizeof(T); // translate from num-elements to bytes
+ return StatusCode::kSuccess;
+ } catch (...) { return DispatchException(); }
+}
+template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, RawCommandQueue*, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, RawCommandQueue*, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, RawCommandQueue*, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, RawCommandQueue*, size_t&);
+template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose,
+ const size_t, const size_t, const size_t,
+ const size_t, const size_t, const size_t, const size_t,
+ const size_t, const size_t, RawCommandQueue*, size_t&);
+
+// =================================================================================================
} // namespace clblast
diff --git a/src/routine.cpp b/src/routine.cpp
index 5a1c0fe9..fa5934f6 100644
--- a/src/routine.cpp
+++ b/src/routine.cpp
@@ -62,28 +62,10 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
device_(queue_.GetDevice()),
db_(kernel_names) {
- InitDatabase(userDatabase);
+ InitDatabase(device_, kernel_names, precision, userDatabase, db_);
InitProgram(source);
}
-void Routine::InitDatabase(const std::vector<database::DatabaseEntry> &userDatabase) {
- const auto platform_id = device_.PlatformID();
- for (const auto &kernel_name : kernel_names_) {
-
- // Queries the cache to see whether or not the kernel parameter database is already there
- bool has_db;
- db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ platform_id, device_(), precision_, kernel_name },
- &has_db);
- if (has_db) { continue; }
-
- // Builds the parameter database for this device and routine set and stores it in the cache
- log_debug("Searching database for kernel '" + kernel_name + "'");
- db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase);
- DatabaseCache::Instance().Store(DatabaseKey{ platform_id, device_(), precision_, kernel_name },
- Database{ db_(kernel_name) });
- }
-}
-
void Routine::InitProgram(std::initializer_list<const char *> source) {
// Determines the identifier for this particular routine call
diff --git a/src/routine.hpp b/src/routine.hpp
index a8f1cb6a..00f7b5cc 100644
--- a/src/routine.hpp
+++ b/src/routine.hpp
@@ -33,6 +33,26 @@ namespace clblast {
class Routine {
public:
+ static void InitDatabase(const Device &device, const std::vector<std::string> &kernel_names,
+ const Precision precision, const std::vector<database::DatabaseEntry> &userDatabase,
+ Databases &db) {
+ const auto platform_id = device.PlatformID();
+ for (const auto &kernel_name : kernel_names) {
+
+ // Queries the cache to see whether or not the kernel parameter database is already there
+ bool has_db;
+ db(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device(), precision, kernel_name},
+ &has_db);
+ if (has_db) { continue; }
+
+ // Builds the parameter database for this device and routine set and stores it in the cache
+ log_debug("Searching database for kernel '" + kernel_name + "'");
+ db(kernel_name) = Database(device, kernel_name, precision, userDatabase);
+ DatabaseCache::Instance().Store(DatabaseKey{platform_id, device(), precision, kernel_name},
+ Database{db(kernel_name)});
+ }
+ }
+
// Base class constructor. The user database is an optional extra database to override the
// built-in database.
// All heavy preparation work is done inside this constructor.
diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp
index f0911d6a..25b1f5c9 100644
--- a/src/routines/level3/xgemm.hpp
+++ b/src/routines/level3/xgemm.hpp
@@ -29,6 +29,40 @@ class Xgemm: public Routine {
static const bool b_want_rotated_;
static const bool c_want_rotated_;
+ // Computes the size of the temporary GEMM buffer based on user-arguments
+ static size_t GetTempSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
+ const size_t m, const size_t n, const size_t k,
+ const size_t a_offset, const size_t a_ld,
+ const size_t b_offset, const size_t b_ld,
+ const size_t c_offset, const size_t c_ld,
+ const size_t mwg, const size_t nwg, const size_t kwg) {
+
+ // Computes the transpose/conjugate options and sets the a/b/c sizes based on that
+ bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate;
+ size_t a_one, a_two, b_one, b_two, c_one, c_two;
+ ProcessArguments(layout, a_transpose, b_transpose, m, n, k,
+ a_one, a_two, b_one, b_two, c_one, c_two,
+ a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate);
+
+ // Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account
+ // whether the matrices need to be rotated or not for the kernel.
+ size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i;
+ CalculateInternalDimensions(m, n, k, mwg, nwg, kwg,
+ a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i);
+
+ // Determines whether or not temporary matrices are needed
+ auto a_no_temp = NoTempBuffer(a_one, a_one_i, a_two, a_two_i, a_ld, a_offset, a_do_transpose, a_conjugate);
+ auto b_no_temp = NoTempBuffer(b_one, b_one_i, b_two, b_two_i, b_ld, b_offset, b_do_transpose, b_conjugate);
+ auto c_no_temp = NoTempBuffer(c_one, c_one_i, c_two, c_two_i, c_ld, c_offset, c_do_transpose, false);
+
+ // Computes the sizes and offsets for (optional) temporary buffers for the 3 matrices
+ auto b_temp_offset = size_t{0};
+ auto c_temp_offset = size_t{0};
+ return ComputeTempSize(a_no_temp, b_no_temp, c_no_temp,
+ a_one_i*a_two_i, b_one_i*b_two_i, c_one_i*c_two_i,
+ b_temp_offset, c_temp_offset);
+ }
+
// Selects which version of GEMM to run
static bool UseDirectKernel(const size_t m, const size_t n, const size_t k,
const size_t min_indirect_size) {