diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-12-28 14:46:45 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-12-28 14:46:45 +0100 |
commit | 6d1e30e61f5ef73f0a83e12f064cae64644034ca (patch) | |
tree | a8874ae7db89acbce71ccf2be560dd13309012b4 | |
parent | aaea9474a154a9f07534523e4ca66e4b2c5f2d4f (diff) |
Added interface to compute the required temporary buffer size for GEMM
-rw-r--r-- | include/clblast.h | 11 | ||||
-rwxr-xr-x | scripts/generator/generator.py | 2 | ||||
-rw-r--r-- | src/clblast.cpp | 53 | ||||
-rw-r--r-- | src/routine.cpp | 20 | ||||
-rw-r--r-- | src/routine.hpp | 20 | ||||
-rw-r--r-- | src/routines/level3/xgemm.hpp | 34 |
6 files changed, 120 insertions, 20 deletions
diff --git a/include/clblast.h b/include/clblast.h index e073b211..3318768a 100644 --- a/include/clblast.h +++ b/include/clblast.h @@ -647,6 +647,17 @@ StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const T // ================================================================================================= +// Retrieves the required size of the temporary buffer for the GEMM kernel (optional) +template <typename T> +StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const size_t a_offset, const size_t a_ld, + const size_t b_offset, const size_t b_ld, + const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, size_t& temp_buffer_size); + +// ================================================================================================= + // CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on // for the same device. This cache can be cleared to free up system memory or in case of debugging. StatusCode PUBLIC_API ClearCache(); diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index 520e3fc8..6c671deb 100755 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -47,7 +47,7 @@ FILES = [ "/src/clblast_cuda.cpp", ] HEADER_LINES = [122, 21, 126, 24, 29, 41, 29, 65, 32, 94, 21] -FOOTER_LINES = [25, 3, 27, 38, 6, 6, 6, 9, 2, 25, 3] +FOOTER_LINES = [36, 56, 27, 38, 6, 6, 6, 9, 2, 25, 3] HEADER_LINES_DOC = 0 FOOTER_LINES_DOC = 63 diff --git a/src/clblast.cpp b/src/clblast.cpp index 7d2c2cef..e38a75ca 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2333,4 +2333,57 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose, cl_command_queue*, cl_event*); // ================================================================================================= + +// Retrieves the required size of the temporary buffer for the GEMM kernel (optional) +template <typename T> +StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const size_t a_offset, const size_t a_ld, + const size_t b_offset, const size_t b_ld, + const size_t c_offset, const size_t c_ld, + RawCommandQueue* queue, size_t& temp_buffer_size) { + try { + + // Retrieves the tuning database + const auto queue_cpp = Queue(*queue); + const auto device = queue_cpp.GetDevice(); + const auto kernel_names = std::vector<std::string>{"Xgemm", "GemmRoutine"}; + Databases db(kernel_names); + Routine::InitDatabase(device, kernel_names, PrecisionValue<T>(), {}, db); + + // Computes the buffer size + if (Xgemm<T>::UseDirectKernel(m, n, k, db["XGEMM_MIN_INDIRECT_SIZE"])) { + temp_buffer_size = 0; + } + else { + temp_buffer_size = Xgemm<T>::GetTempSize(layout, a_transpose, b_transpose, m, n, k, + a_offset, a_ld, b_offset, b_ld, c_offset, c_ld, + db["MWG"], db["NWG"], db["KWG"]); + } + temp_buffer_size *= sizeof(T); // translate from num-elements to bytes + return StatusCode::kSuccess; + } catch (...) { return DispatchException(); } +} +template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, RawCommandQueue*, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, RawCommandQueue*, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, RawCommandQueue*, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, RawCommandQueue*, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, RawCommandQueue*, size_t&); + +// ================================================================================================= } // namespace clblast diff --git a/src/routine.cpp b/src/routine.cpp index 5a1c0fe9..fa5934f6 100644 --- a/src/routine.cpp +++ b/src/routine.cpp @@ -62,28 +62,10 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name, device_(queue_.GetDevice()), db_(kernel_names) { - InitDatabase(userDatabase); + InitDatabase(device_, kernel_names, precision, userDatabase, db_); InitProgram(source); } -void Routine::InitDatabase(const std::vector<database::DatabaseEntry> &userDatabase) { - const auto platform_id = device_.PlatformID(); - for (const auto &kernel_name : kernel_names_) { - - // Queries the cache to see whether or not the kernel parameter database is already there - bool has_db; - db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ platform_id, device_(), precision_, kernel_name }, - &has_db); - if (has_db) { continue; } - - // Builds the parameter database for this device and routine set and stores it in the cache - log_debug("Searching database for kernel '" + kernel_name + "'"); - db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase); - DatabaseCache::Instance().Store(DatabaseKey{ platform_id, device_(), precision_, kernel_name }, - Database{ db_(kernel_name) }); - } -} - void Routine::InitProgram(std::initializer_list<const char *> source) { // Determines the identifier for this particular routine call diff --git a/src/routine.hpp b/src/routine.hpp index a8f1cb6a..00f7b5cc 100644 --- a/src/routine.hpp +++ b/src/routine.hpp @@ -33,6 +33,26 @@ namespace clblast { class Routine { public: + static void InitDatabase(const Device &device, const std::vector<std::string> &kernel_names, + const Precision precision, const std::vector<database::DatabaseEntry> &userDatabase, + Databases &db) { + const auto platform_id = device.PlatformID(); + for (const auto &kernel_name : kernel_names) { + + // Queries the cache to see whether or not the kernel parameter database is already there + bool has_db; + db(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device(), precision, kernel_name}, + &has_db); + if (has_db) { continue; } + + // Builds the parameter database for this device and routine set and stores it in the cache + log_debug("Searching database for kernel '" + kernel_name + "'"); + db(kernel_name) = Database(device, kernel_name, precision, userDatabase); + DatabaseCache::Instance().Store(DatabaseKey{platform_id, device(), precision, kernel_name}, + Database{db(kernel_name)}); + } + } + // Base class constructor. The user database is an optional extra database to override the // built-in database. // All heavy preparation work is done inside this constructor. diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp index f0911d6a..25b1f5c9 100644 --- a/src/routines/level3/xgemm.hpp +++ b/src/routines/level3/xgemm.hpp @@ -29,6 +29,40 @@ class Xgemm: public Routine { static const bool b_want_rotated_; static const bool c_want_rotated_; + // Computes the size of the temporary GEMM buffer based on user-arguments + static size_t GetTempSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const size_t a_offset, const size_t a_ld, + const size_t b_offset, const size_t b_ld, + const size_t c_offset, const size_t c_ld, + const size_t mwg, const size_t nwg, const size_t kwg) { + + // Computes the transpose/conjugate options and sets the a/b/c sizes based on that + bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate; + size_t a_one, a_two, b_one, b_two, c_one, c_two; + ProcessArguments(layout, a_transpose, b_transpose, m, n, k, + a_one, a_two, b_one, b_two, c_one, c_two, + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); + + // Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account + // whether the matrices need to be rotated or not for the kernel. + size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i; + CalculateInternalDimensions(m, n, k, mwg, nwg, kwg, + a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i); + + // Determines whether or not temporary matrices are needed + auto a_no_temp = NoTempBuffer(a_one, a_one_i, a_two, a_two_i, a_ld, a_offset, a_do_transpose, a_conjugate); + auto b_no_temp = NoTempBuffer(b_one, b_one_i, b_two, b_two_i, b_ld, b_offset, b_do_transpose, b_conjugate); + auto c_no_temp = NoTempBuffer(c_one, c_one_i, c_two, c_two_i, c_ld, c_offset, c_do_transpose, false); + + // Computes the sizes and offsets for (optional) temporary buffers for the 3 matrices + auto b_temp_offset = size_t{0}; + auto c_temp_offset = size_t{0}; + return ComputeTempSize(a_no_temp, b_no_temp, c_no_temp, + a_one_i*a_two_i, b_one_i*b_two_i, c_one_i*c_two_i, + b_temp_offset, c_temp_offset); + } + // Selects which version of GEMM to run static bool UseDirectKernel(const size_t m, const size_t n, const size_t k, const size_t min_indirect_size) { |