diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/clblast.cpp | 71 | ||||
-rw-r--r-- | src/clblast_cuda.cpp | 70 | ||||
-rw-r--r-- | src/clpp11.hpp | 9 | ||||
-rw-r--r-- | src/cupp11.hpp | 6 | ||||
-rw-r--r-- | src/database/database.cpp | 38 | ||||
-rw-r--r-- | src/database/database.hpp | 2 | ||||
-rw-r--r-- | src/routine.cpp | 20 | ||||
-rw-r--r-- | src/routine.hpp | 20 | ||||
-rw-r--r-- | src/routines/level3/xgemm.cpp | 107 | ||||
-rw-r--r-- | src/routines/level3/xgemm.hpp | 134 |
10 files changed, 350 insertions, 127 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp index 7d2c2cef..f5e2f1be 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -1651,17 +1651,21 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const T beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, - cl_command_queue* queue, cl_event* event) { + cl_command_queue* queue, cl_event* event, + cl_mem temp_buffer) { try { auto queue_cpp = Queue(*queue); auto routine = Xgemm<T>(queue_cpp, event); + const auto temp_buffer_provided = temp_buffer != nullptr; + auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(nullptr); routine.DoGemm(layout, a_transpose, b_transpose, m, n, k, alpha, Buffer<T>(a_buffer), a_offset, a_ld, Buffer<T>(b_buffer), b_offset, b_ld, beta, - Buffer<T>(c_buffer), c_offset, c_ld); + Buffer<T>(c_buffer), c_offset, c_ld, + temp_buffer_cpp, temp_buffer_provided); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } @@ -1672,7 +1676,7 @@ template StatusCode PUBLIC_API Gemm<float>(const Layout, const Transpose, const const cl_mem, const size_t, const size_t, const float, cl_mem, const size_t, const size_t, - cl_command_queue*, cl_event*); + cl_command_queue*, cl_event*, cl_mem); template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const double, @@ -1680,7 +1684,7 @@ template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const const cl_mem, const size_t, const size_t, const double, cl_mem, const size_t, const size_t, - cl_command_queue*, cl_event*); + cl_command_queue*, cl_event*, cl_mem); template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const float2, @@ -1688,7 +1692,7 @@ template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const const cl_mem, const size_t, const size_t, const float2, cl_mem, const size_t, const size_t, - cl_command_queue*, cl_event*); + cl_command_queue*, cl_event*, cl_mem); template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const double2, @@ -1696,7 +1700,7 @@ template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, cons const cl_mem, const size_t, const size_t, const double2, cl_mem, const size_t, const size_t, - cl_command_queue*, cl_event*); + cl_command_queue*, cl_event*, cl_mem); template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const half, @@ -1704,7 +1708,7 @@ template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const T const cl_mem, const size_t, const size_t, const half, cl_mem, const size_t, const size_t, - cl_command_queue*, cl_event*); + cl_command_queue*, cl_event*, cl_mem); // Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM template <typename T> @@ -2333,4 +2337,57 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose, cl_command_queue*, cl_event*); // ================================================================================================= + +// Retrieves the required size of the temporary buffer for the GEMM kernel (optional) +template <typename T> +StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const size_t a_offset, const size_t a_ld, + const size_t b_offset, const size_t b_ld, + const size_t c_offset, const size_t c_ld, + cl_command_queue* queue, size_t& temp_buffer_size) { + try { + + // Retrieves the tuning database + const auto queue_cpp = Queue(*queue); + const auto device = queue_cpp.GetDevice(); + const auto kernel_names = std::vector<std::string>{"Xgemm", "GemmRoutine"}; + Databases db(kernel_names); + Routine::InitDatabase(device, kernel_names, PrecisionValue<T>(), {}, db); + + // Computes the buffer size + if (Xgemm<T>::UseDirectKernel(m, n, k, db["XGEMM_MIN_INDIRECT_SIZE"])) { + temp_buffer_size = 0; + } + else { + temp_buffer_size = Xgemm<T>::GetTempSize(layout, a_transpose, b_transpose, m, n, k, + a_offset, a_ld, b_offset, b_ld, c_offset, c_ld, + db["MWG"], db["NWG"], db["KWG"]); + } + temp_buffer_size *= sizeof(T); // translate from num-elements to bytes + return StatusCode::kSuccess; + } catch (...) { return DispatchException(); } +} +template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, cl_command_queue*, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, cl_command_queue*, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, cl_command_queue*, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, cl_command_queue*, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, cl_command_queue*, size_t&); + +// ================================================================================================= } // namespace clblast diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp index 0e3d949d..348ff3f5 100644 --- a/src/clblast_cuda.cpp +++ b/src/clblast_cuda.cpp @@ -1725,19 +1725,23 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const T beta, CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, - const CUcontext context, const CUdevice device) { + const CUcontext context, const CUdevice device, + CUdeviceptr temp_buffer) { try { const auto context_cpp = Context(context); const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xgemm<T>(queue_cpp, nullptr); + const auto temp_buffer_provided = temp_buffer != 0; + auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(0); routine.DoGemm(layout, a_transpose, b_transpose, m, n, k, alpha, Buffer<T>(a_buffer), a_offset, a_ld, Buffer<T>(b_buffer), b_offset, b_ld, beta, - Buffer<T>(c_buffer), c_offset, c_ld); + Buffer<T>(c_buffer), c_offset, c_ld, + temp_buffer_cpp, temp_buffer_provided); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } @@ -1748,7 +1752,7 @@ template StatusCode PUBLIC_API Gemm<float>(const Layout, const Transpose, const const CUdeviceptr, const size_t, const size_t, const float, CUdeviceptr, const size_t, const size_t, - const CUcontext, const CUdevice); + const CUcontext, const CUdevice, CUdeviceptr); template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const double, @@ -1756,7 +1760,7 @@ template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const const CUdeviceptr, const size_t, const size_t, const double, CUdeviceptr, const size_t, const size_t, - const CUcontext, const CUdevice); + const CUcontext, const CUdevice, CUdeviceptr); template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const float2, @@ -1764,7 +1768,7 @@ template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const const CUdeviceptr, const size_t, const size_t, const float2, CUdeviceptr, const size_t, const size_t, - const CUcontext, const CUdevice); + const CUcontext, const CUdevice, CUdeviceptr); template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const double2, @@ -1772,7 +1776,7 @@ template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, cons const CUdeviceptr, const size_t, const size_t, const double2, CUdeviceptr, const size_t, const size_t, - const CUcontext, const CUdevice); + const CUcontext, const CUdevice, CUdeviceptr); template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const half, @@ -1780,7 +1784,7 @@ template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const T const CUdeviceptr, const size_t, const size_t, const half, CUdeviceptr, const size_t, const size_t, - const CUcontext, const CUdevice); + const CUcontext, const CUdevice, CUdeviceptr); // Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM template <typename T> @@ -2433,4 +2437,56 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose, const CUcontext, const CUdevice); // ================================================================================================= + +// Retrieves the required size of the temporary buffer for the GEMM kernel (optional) +template <typename T> +StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const size_t a_offset, const size_t a_ld, + const size_t b_offset, const size_t b_ld, + const size_t c_offset, const size_t c_ld, + const CUdevice device, size_t& temp_buffer_size) { + try { + + // Retrieves the tuning database + const auto device_cpp = Device(device); + const auto kernel_names = std::vector<std::string>{"Xgemm", "GemmRoutine"}; + Databases db(kernel_names); + Routine::InitDatabase(device_cpp, kernel_names, PrecisionValue<T>(), {}, db); + + // Computes the buffer size + if (Xgemm<T>::UseDirectKernel(m, n, k, db["XGEMM_MIN_INDIRECT_SIZE"])) { + temp_buffer_size = 0; + } + else { + temp_buffer_size = Xgemm<T>::GetTempSize(layout, a_transpose, b_transpose, m, n, k, + a_offset, a_ld, b_offset, b_ld, c_offset, c_ld, + db["MWG"], db["NWG"], db["KWG"]); + } + temp_buffer_size *= sizeof(T); // translate from num-elements to bytes + return StatusCode::kSuccess; + } catch (...) { return DispatchException(); } +} +template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, const CUdevice, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, const CUdevice, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, const CUdevice, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, const CUdevice, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, const CUdevice, size_t&); + +// ================================================================================================= } // namespace clblast diff --git a/src/clpp11.hpp b/src/clpp11.hpp index 6ebf1322..2119f26b 100644 --- a/src/clpp11.hpp +++ b/src/clpp11.hpp @@ -614,10 +614,11 @@ class Buffer { } // Regular constructor with memory management. If this class does not own the buffer object, then - // the memory will not be freed automatically afterwards. + // the memory will not be freed automatically afterwards. If the size is set to 0, this will + // become a stub containing a nullptr explicit Buffer(const Context &context, const BufferAccess access, const size_t size): - buffer_(new cl_mem, [access](cl_mem* m) { - if (access != BufferAccess::kNotOwned) { CheckError(clReleaseMemObject(*m)); } + buffer_(new cl_mem, [access, size](cl_mem* m) { + if (access != BufferAccess::kNotOwned && size > 0) { CheckError(clReleaseMemObject(*m)); } delete m; }), access_(access) { @@ -625,7 +626,7 @@ class Buffer { if (access_ == BufferAccess::kReadOnly) { flags = CL_MEM_READ_ONLY; } if (access_ == BufferAccess::kWriteOnly) { flags = CL_MEM_WRITE_ONLY; } auto status = CL_SUCCESS; - *buffer_ = clCreateBuffer(context(), flags, size*sizeof(T), nullptr, &status); + *buffer_ = (size > 0) ? clCreateBuffer(context(), flags, size*sizeof(T), nullptr, &status) : nullptr; CLCudaAPIError::Check(status, "clCreateBuffer"); } diff --git a/src/cupp11.hpp b/src/cupp11.hpp index eb177ca2..509ae3e8 100644 --- a/src/cupp11.hpp +++ b/src/cupp11.hpp @@ -549,12 +549,12 @@ public: // Regular constructor with memory management. If this class does not own the buffer object, then // the memory will not be freed automatically afterwards. explicit Buffer(const Context &, const BufferAccess access, const size_t size): - buffer_(new CUdeviceptr, [access](CUdeviceptr* m) { - if (access != BufferAccess::kNotOwned) { CheckError(cuMemFree(*m)); } + buffer_(new CUdeviceptr, [access, size](CUdeviceptr* m) { + if (access != BufferAccess::kNotOwned && size > 0) { CheckError(cuMemFree(*m)); } delete m; }), access_(access) { - CheckError(cuMemAlloc(buffer_.get(), size*sizeof(T))); + if (size > 0) { CheckError(cuMemAlloc(buffer_.get(), size*sizeof(T))); } } // As above, but now with read/write access as a default diff --git a/src/database/database.cpp b/src/database/database.cpp index ed56c65d..b2f70e49 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -39,6 +39,7 @@ namespace clblast { // ================================================================================================= +std::vector<database::DatabaseEntry> Database::database = std::vector<database::DatabaseEntry>{}; const std::vector<database::DatabaseEntry> Database::apple_cpu_fallback = std::vector<database::DatabaseEntry>{ database::XaxpyApple, database::XdotApple, database::XgemvApple, database::XgemvFastApple, database::XgemvFastRotApple, database::XgerApple, database::XtrsvApple, @@ -58,23 +59,26 @@ Database::Database(const Device &device, const std::string &kernel_name, const Precision precision, const std::vector<database::DatabaseEntry> &overlay): parameters_(std::make_shared<database::Parameters>()) { - database = std::vector<database::DatabaseEntry>{ - database::XaxpyHalf, database::XaxpySingle, database::XaxpyDouble, database::XaxpyComplexSingle, database::XaxpyComplexDouble, - database::XdotHalf, database::XdotSingle, database::XdotDouble, database::XdotComplexSingle, database::XdotComplexDouble, - database::XgemvHalf, database::XgemvSingle, database::XgemvDouble, database::XgemvComplexSingle, database::XgemvComplexDouble, - database::XgemvFastHalf, database::XgemvFastSingle, database::XgemvFastDouble, database::XgemvFastComplexSingle, database::XgemvFastComplexDouble, - database::XgemvFastRotHalf, database::XgemvFastRotSingle, database::XgemvFastRotDouble, database::XgemvFastRotComplexSingle, database::XgemvFastRotComplexDouble, - database::XgerHalf, database::XgerSingle, database::XgerDouble, database::XgerComplexSingle, database::XgerComplexDouble, - database::XgemmHalf, database::XgemmSingle, database::XgemmDouble, database::XgemmComplexSingle, database::XgemmComplexDouble, - database::XgemmDirectHalf, database::XgemmDirectSingle, database::XgemmDirectDouble, database::XgemmDirectComplexSingle, database::XgemmDirectComplexDouble, - database::CopyHalf, database::CopySingle, database::CopyDouble, database::CopyComplexSingle, database::CopyComplexDouble, - database::PadHalf, database::PadSingle, database::PadDouble, database::PadComplexSingle, database::PadComplexDouble, - database::TransposeHalf, database::TransposeSingle, database::TransposeDouble, database::TransposeComplexSingle, database::TransposeComplexDouble, - database::PadtransposeHalf, database::PadtransposeSingle, database::PadtransposeDouble, database::PadtransposeComplexSingle, database::PadtransposeComplexDouble, - database::InvertHalf, database::InvertSingle, database::InvertDouble, database::InvertComplexSingle, database::InvertComplexDouble, - database::GemmRoutineHalf, database::GemmRoutineSingle, database::GemmRoutineDouble, database::GemmRoutineComplexSingle, database::GemmRoutineComplexDouble, - database::TrsvRoutineHalf, database::TrsvRoutineSingle, database::TrsvRoutineDouble, database::TrsvRoutineComplexSingle, database::TrsvRoutineComplexDouble - }; + // Initializes the static variable on first use. At this point we are sure all global variables are initialized + if (database.size() == 0) { + database = std::vector<database::DatabaseEntry>{ + database::XaxpyHalf, database::XaxpySingle, database::XaxpyDouble, database::XaxpyComplexSingle, database::XaxpyComplexDouble, + database::XdotHalf, database::XdotSingle, database::XdotDouble, database::XdotComplexSingle, database::XdotComplexDouble, + database::XgemvHalf, database::XgemvSingle, database::XgemvDouble, database::XgemvComplexSingle, database::XgemvComplexDouble, + database::XgemvFastHalf, database::XgemvFastSingle, database::XgemvFastDouble, database::XgemvFastComplexSingle, database::XgemvFastComplexDouble, + database::XgemvFastRotHalf, database::XgemvFastRotSingle, database::XgemvFastRotDouble, database::XgemvFastRotComplexSingle, database::XgemvFastRotComplexDouble, + database::XgerHalf, database::XgerSingle, database::XgerDouble, database::XgerComplexSingle, database::XgerComplexDouble, + database::XgemmHalf, database::XgemmSingle, database::XgemmDouble, database::XgemmComplexSingle, database::XgemmComplexDouble, + database::XgemmDirectHalf, database::XgemmDirectSingle, database::XgemmDirectDouble, database::XgemmDirectComplexSingle, database::XgemmDirectComplexDouble, + database::CopyHalf, database::CopySingle, database::CopyDouble, database::CopyComplexSingle, database::CopyComplexDouble, + database::PadHalf, database::PadSingle, database::PadDouble, database::PadComplexSingle, database::PadComplexDouble, + database::TransposeHalf, database::TransposeSingle, database::TransposeDouble, database::TransposeComplexSingle, database::TransposeComplexDouble, + database::PadtransposeHalf, database::PadtransposeSingle, database::PadtransposeDouble, database::PadtransposeComplexSingle, database::PadtransposeComplexDouble, + database::InvertHalf, database::InvertSingle, database::InvertDouble, database::InvertComplexSingle, database::InvertComplexDouble, + database::GemmRoutineHalf, database::GemmRoutineSingle, database::GemmRoutineDouble, database::GemmRoutineComplexSingle, database::GemmRoutineComplexDouble, + database::TrsvRoutineHalf, database::TrsvRoutineSingle, database::TrsvRoutineDouble, database::TrsvRoutineComplexSingle, database::TrsvRoutineComplexDouble + }; + } // Finds device information const auto device_type = GetDeviceType(device); diff --git a/src/database/database.hpp b/src/database/database.hpp index de4306bc..8e53e013 100644 --- a/src/database/database.hpp +++ b/src/database/database.hpp @@ -35,7 +35,7 @@ class Database { static const std::string kDeviceVendorAll; // The database consists of separate database entries, stored together in a vector - std::vector<database::DatabaseEntry> database; + static std::vector<database::DatabaseEntry> database; // Database for a special case: Apple CPUs support limited number of threads static const std::vector<database::DatabaseEntry> apple_cpu_fallback; diff --git a/src/routine.cpp b/src/routine.cpp index 5a1c0fe9..fa5934f6 100644 --- a/src/routine.cpp +++ b/src/routine.cpp @@ -62,28 +62,10 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name, device_(queue_.GetDevice()), db_(kernel_names) { - InitDatabase(userDatabase); + InitDatabase(device_, kernel_names, precision, userDatabase, db_); InitProgram(source); } -void Routine::InitDatabase(const std::vector<database::DatabaseEntry> &userDatabase) { - const auto platform_id = device_.PlatformID(); - for (const auto &kernel_name : kernel_names_) { - - // Queries the cache to see whether or not the kernel parameter database is already there - bool has_db; - db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ platform_id, device_(), precision_, kernel_name }, - &has_db); - if (has_db) { continue; } - - // Builds the parameter database for this device and routine set and stores it in the cache - log_debug("Searching database for kernel '" + kernel_name + "'"); - db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase); - DatabaseCache::Instance().Store(DatabaseKey{ platform_id, device_(), precision_, kernel_name }, - Database{ db_(kernel_name) }); - } -} - void Routine::InitProgram(std::initializer_list<const char *> source) { // Determines the identifier for this particular routine call diff --git a/src/routine.hpp b/src/routine.hpp index a8f1cb6a..00f7b5cc 100644 --- a/src/routine.hpp +++ b/src/routine.hpp @@ -33,6 +33,26 @@ namespace clblast { class Routine { public: + static void InitDatabase(const Device &device, const std::vector<std::string> &kernel_names, + const Precision precision, const std::vector<database::DatabaseEntry> &userDatabase, + Databases &db) { + const auto platform_id = device.PlatformID(); + for (const auto &kernel_name : kernel_names) { + + // Queries the cache to see whether or not the kernel parameter database is already there + bool has_db; + db(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device(), precision, kernel_name}, + &has_db); + if (has_db) { continue; } + + // Builds the parameter database for this device and routine set and stores it in the cache + log_debug("Searching database for kernel '" + kernel_name + "'"); + db(kernel_name) = Database(device, kernel_name, precision, userDatabase); + DatabaseCache::Instance().Store(DatabaseKey{platform_id, device(), precision, kernel_name}, + Database{db(kernel_name)}); + } + } + // Base class constructor. The user database is an optional extra database to override the // built-in database. // All heavy preparation work is done inside this constructor. diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp index edba1f00..4c1b9558 100644 --- a/src/routines/level3/xgemm.cpp +++ b/src/routines/level3/xgemm.cpp @@ -19,6 +19,11 @@ namespace clblast { // ================================================================================================= +// Defines the assumptions of the GEMM kernels +template <typename T> const bool Xgemm<T>::a_want_rotated_ = false; +template <typename T> const bool Xgemm<T>::b_want_rotated_ = true; +template <typename T> const bool Xgemm<T>::c_want_rotated_ = false; + // Constructor: forwards to base class constructor template <typename T> Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name): @@ -56,40 +61,15 @@ void Xgemm<T>::DoGemm(const Layout layout, const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, const T beta, - const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld) { - - // Makes sure all dimensions are larger than zero - if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); } - - // Computes whether or not the matrices are transposed in memory. This is based on their layout - // (row or column-major) and whether or not they are requested to be pre-transposed. Note - // that the Xgemm kernel expects either matrices A and C (in case of row-major) or B (in case of - // col-major) to be transformed, so transposing requirements are not the same as whether or not - // the matrix is actually transposed in memory. - const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) || - (layout == Layout::kRowMajor && a_transpose == Transpose::kNo); - const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) || - (layout == Layout::kRowMajor && b_transpose == Transpose::kNo); - const auto c_rotated = (layout == Layout::kRowMajor); - static const auto a_want_rotated = false; - static const auto b_want_rotated = true; - static const auto c_want_rotated = false; - const auto a_do_transpose = a_rotated != a_want_rotated; - const auto b_do_transpose = b_rotated != b_want_rotated; - const auto c_do_transpose = c_rotated != c_want_rotated; - - // In case of complex data-types, the transpose can also become a conjugate transpose - const auto a_conjugate = (a_transpose == Transpose::kConjugate); - const auto b_conjugate = (b_transpose == Transpose::kConjugate); - - // Computes the first and second dimensions of the 3 matrices taking into account whether the - // matrices are rotated or not - const auto a_one = (a_rotated) ? k : m; - const auto a_two = (a_rotated) ? m : k; - const auto b_one = (b_rotated) ? n : k; - const auto b_two = (b_rotated) ? k : n; - const auto c_one = (c_rotated) ? n : m; - const auto c_two = (c_rotated) ? m : n; + const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, + const Buffer<T> &temp_buffer, const bool temp_buffer_provided) { // optional arguments + + // Computes the transpose/conjugate options and sets the a/b/c sizes based on that + bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate; + size_t a_one, a_two, b_one, b_two, c_one, c_two; + ProcessArguments(layout, a_transpose, b_transpose, m, n, k, + a_one, a_two, b_one, b_two, c_one, c_two, + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); // Tests three matrices (A, B, C) for validity, first from a perspective of the OpenCL buffers and // their sizes, and then from a perspective of parameter values (e.g. m, n, k). Tests whether the @@ -103,11 +83,7 @@ void Xgemm<T>::DoGemm(const Layout layout, TestMatrixC(c_one, c_two, c_buffer, c_offset, c_ld); // Selects which version of GEMM to run - const auto m_n_k = static_cast<unsigned long long>(m) * static_cast<unsigned long long>(n) * - static_cast<unsigned long long>(k); - const auto database_value = static_cast<unsigned long long>(db_["XGEMM_MIN_INDIRECT_SIZE"]); - const auto min_indirect_size = database_value * database_value * database_value; - const auto do_gemm_direct = (m_n_k < min_indirect_size); + const auto do_gemm_direct = UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]); if (do_gemm_direct) { // for small sizes (single kernel) GemmDirect(m, n, k, alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta, @@ -119,9 +95,8 @@ void Xgemm<T>::DoGemm(const Layout layout, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta, c_buffer, c_offset, c_ld, a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate, - a_one, a_two, a_want_rotated, - b_one, b_two, b_want_rotated, - c_one, c_two, c_want_rotated); + a_one, a_two, b_one, b_two, c_one, c_two, + temp_buffer, temp_buffer_provided); } } @@ -139,9 +114,11 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k, const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, const bool a_conjugate, const bool b_conjugate, - const size_t a_one, const size_t a_two, const bool a_want_rotated, - const size_t b_one, const size_t b_two, const bool b_want_rotated, - const size_t c_one, const size_t c_two, const bool c_want_rotated) { + const size_t a_one, const size_t a_two, + const size_t b_one, const size_t b_two, + const size_t c_one, const size_t c_two, + const Buffer<T> &temp_buffer, const bool temp_buffer_provided) { + // Calculates the ceiled versions of m, n, and k const auto m_ceiled = Ceil(m, db_["MWG"]); const auto n_ceiled = Ceil(n, db_["NWG"]); @@ -149,39 +126,39 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k, // Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account // whether the matrices need to be rotated or not for the kernel. - const auto a_one_i = (a_want_rotated) ? k_ceiled : m_ceiled; - const auto a_two_i = (a_want_rotated) ? m_ceiled : k_ceiled; - const auto b_one_i = (b_want_rotated) ? n_ceiled : k_ceiled; - const auto b_two_i = (b_want_rotated) ? k_ceiled : n_ceiled; - const auto c_one_i = (c_want_rotated) ? n_ceiled : m_ceiled; - const auto c_two_i = (c_want_rotated) ? m_ceiled : n_ceiled; + size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i; + CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"], + a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i); // Determines whether or not temporary matrices are needed - auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && a_offset == 0 && - a_do_transpose == false && a_conjugate == false; - auto b_no_temp = b_one == b_one_i && b_two == b_two_i && b_ld == b_one && b_offset == 0 && - b_do_transpose == false && b_conjugate == false; - auto c_no_temp = c_one == c_one_i && c_two == c_two_i && c_ld == c_one && c_offset == 0 && - c_do_transpose == false; + auto a_no_temp = NoTempBuffer(a_one, a_one_i, a_two, a_two_i, a_ld, a_offset, a_do_transpose, a_conjugate); + auto b_no_temp = NoTempBuffer(b_one, b_one_i, b_two, b_two_i, b_ld, b_offset, b_do_transpose, b_conjugate); + auto c_no_temp = NoTempBuffer(c_one, c_one_i, c_two, c_two_i, c_ld, c_offset, c_do_transpose, false); // Computes the sizes and offsets for (optional) temporary buffers for the 3 matrices - auto temp_size = size_t{0}; auto b_temp_offset = size_t{0}; auto c_temp_offset = size_t{0}; - if (!a_no_temp) { temp_size += a_one_i*a_two_i; } - if (!b_no_temp) { b_temp_offset = temp_size; temp_size += b_one_i*b_two_i; } - if (!c_no_temp) { c_temp_offset = temp_size; temp_size += c_one_i*c_two_i; } + const auto temp_size = ComputeTempSize(a_no_temp, b_no_temp, c_no_temp, + a_one_i*a_two_i, b_one_i*b_two_i, c_one_i*c_two_i, + b_temp_offset, c_temp_offset); if (!IsMultiple(b_temp_offset, db_["VWN"])) { throw BLASError(StatusCode::kUnexpectedError); } if (!IsMultiple(c_temp_offset, db_["VWM"])) { throw BLASError(StatusCode::kUnexpectedError); } // Creates the buffer for the (optional) temporary matrices. Note that we use 'a_buffer' in case // when no temporary buffer is needed, but that's just to make it compile: it is never used. - const auto temp_buffer = (temp_size > 0) ? Buffer<T>(context_, temp_size) : a_buffer; + const auto temp_buffer_all = (temp_buffer_provided) ? temp_buffer : + ((temp_size > 0) ? Buffer<T>(context_, temp_size) : a_buffer); + + // Verifies if the provided temporary buffer is large enough + if (temp_buffer_provided) { + const auto required_size = temp_size * sizeof(T); + if (temp_buffer_all.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryTemp); } + } // Sets the buffer pointers for (temp) matrices A, B, and C - const auto a_temp = (a_no_temp) ? a_buffer : temp_buffer; - const auto b_temp = (b_no_temp) ? b_buffer : temp_buffer; - const auto c_temp = (c_no_temp) ? c_buffer : temp_buffer; + const auto a_temp = (a_no_temp) ? a_buffer : temp_buffer_all; + const auto b_temp = (b_no_temp) ? b_buffer : temp_buffer_all; + const auto c_temp = (c_no_temp) ? c_buffer : temp_buffer_all; // Events of all kernels (including pre/post processing kernels) auto eventWaitList = std::vector<Event>(); diff --git a/src/routines/level3/xgemm.hpp b/src/routines/level3/xgemm.hpp index c61611b6..b51d1771 100644 --- a/src/routines/level3/xgemm.hpp +++ b/src/routines/level3/xgemm.hpp @@ -24,6 +24,130 @@ template <typename T> class Xgemm: public Routine { public: + // Defines the assumptions of the GEMM kernels + static const bool a_want_rotated_; + static const bool b_want_rotated_; + static const bool c_want_rotated_; + + // Computes the size of the temporary GEMM buffer based on user-arguments + static size_t GetTempSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const size_t a_offset, const size_t a_ld, + const size_t b_offset, const size_t b_ld, + const size_t c_offset, const size_t c_ld, + const size_t mwg, const size_t nwg, const size_t kwg) { + + // Computes the transpose/conjugate options and sets the a/b/c sizes based on that + bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate; + size_t a_one, a_two, b_one, b_two, c_one, c_two; + ProcessArguments(layout, a_transpose, b_transpose, m, n, k, + a_one, a_two, b_one, b_two, c_one, c_two, + a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate); + + // Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account + // whether the matrices need to be rotated or not for the kernel. + size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i; + CalculateInternalDimensions(m, n, k, mwg, nwg, kwg, + a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i); + + // Determines whether or not temporary matrices are needed + auto a_no_temp = NoTempBuffer(a_one, a_one_i, a_two, a_two_i, a_ld, a_offset, a_do_transpose, a_conjugate); + auto b_no_temp = NoTempBuffer(b_one, b_one_i, b_two, b_two_i, b_ld, b_offset, b_do_transpose, b_conjugate); + auto c_no_temp = NoTempBuffer(c_one, c_one_i, c_two, c_two_i, c_ld, c_offset, c_do_transpose, false); + + // Computes the sizes and offsets for (optional) temporary buffers for the 3 matrices + auto b_temp_offset = size_t{0}; + auto c_temp_offset = size_t{0}; + return ComputeTempSize(a_no_temp, b_no_temp, c_no_temp, + a_one_i*a_two_i, b_one_i*b_two_i, c_one_i*c_two_i, + b_temp_offset, c_temp_offset); + } + + // Selects which version of GEMM to run + static bool UseDirectKernel(const size_t m, const size_t n, const size_t k, + const size_t min_indirect_size) { + const auto m_n_k = static_cast<unsigned long long>(m) * static_cast<unsigned long long>(n) * + static_cast<unsigned long long>(k); + const auto min_indirect_size_ll = static_cast<unsigned long long>(min_indirect_size); + const auto min_indirect_size_e3 = min_indirect_size_ll * min_indirect_size_ll * min_indirect_size_ll; + return (m_n_k < min_indirect_size_e3); + } + + // Process the user-arguments, computes secondary parameters + static void ProcessArguments(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + size_t& a_one, size_t& a_two, size_t& b_one, + size_t& b_two, size_t& c_one, size_t& c_two, + bool& a_do_transpose, bool& b_do_transpose, bool& c_do_transpose, + bool& a_conjugate, bool& b_conjugate) { + + // Makes sure all dimensions are larger than zero + if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); } + + // Computes whether or not the matrices are transposed in memory. This is based on their layout + // (row or column-major) and whether or not they are requested to be pre-transposed. Note + // that the Xgemm kernel expects either matrices A and C (in case of row-major) or B (in case of + // col-major) to be transformed, so transposing requirements are not the same as whether or not + // the matrix is actually transposed in memory. + const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) || + (layout == Layout::kRowMajor && a_transpose == Transpose::kNo); + const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) || + (layout == Layout::kRowMajor && b_transpose == Transpose::kNo); + const auto c_rotated = (layout == Layout::kRowMajor); + a_do_transpose = a_rotated != a_want_rotated_; + b_do_transpose = b_rotated != b_want_rotated_; + c_do_transpose = c_rotated != c_want_rotated_; + + // In case of complex data-types, the transpose can also become a conjugate transpose + a_conjugate = (a_transpose == Transpose::kConjugate); + b_conjugate = (b_transpose == Transpose::kConjugate); + + // Computes the first and second dimensions of the 3 matrices taking into account whether the + // matrices are rotated or not + a_one = (a_rotated) ? k : m; + a_two = (a_rotated) ? m : k; + b_one = (b_rotated) ? n : k; + b_two = (b_rotated) ? k : n; + c_one = (c_rotated) ? n : m; + c_two = (c_rotated) ? m : n; + } + + // Computes the sizes and offsets for (optional) temporary buffers for the 3 matrices + static size_t ComputeTempSize(const bool a_no_temp, const bool b_no_temp, const bool c_no_temp, + const size_t a_size, const size_t b_size, const size_t c_size, + size_t &b_temp_offset, size_t &c_temp_offset) { + auto temp_size = size_t{0}; + if (!a_no_temp) { temp_size += a_size; } + if (!b_no_temp) { b_temp_offset = temp_size; temp_size += b_size; } + if (!c_no_temp) { c_temp_offset = temp_size; temp_size += c_size; } + return temp_size; + } + + // Determines whether or not temporary matrices are needed + static bool NoTempBuffer(const size_t one, const size_t one_i, const size_t two, const size_t two_i, + const size_t ld, const size_t offset, + const bool do_transpose, const bool conjugate) { + return one == one_i && two == two_i && ld == one && offset == 0 && !do_transpose && !conjugate; + } + + + // Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account + // whether the matrices need to be rotated or not for the kernel. + static void CalculateInternalDimensions(const size_t m, const size_t n, const size_t k, + const size_t mwg, const size_t nwg, const size_t kwg, + size_t& a_one_i, size_t& a_two_i, size_t& b_one_i, + size_t& b_two_i, size_t& c_one_i, size_t& c_two_i) { + const auto m_ceiled = Ceil(m, mwg); + const auto n_ceiled = Ceil(n, nwg); + const auto k_ceiled = Ceil(k, kwg); + a_one_i = (a_want_rotated_) ? k_ceiled : m_ceiled; + a_two_i = (a_want_rotated_) ? m_ceiled : k_ceiled; + b_one_i = (b_want_rotated_) ? n_ceiled : k_ceiled; + b_two_i = (b_want_rotated_) ? k_ceiled : n_ceiled; + c_one_i = (c_want_rotated_) ? n_ceiled : m_ceiled; + c_two_i = (c_want_rotated_) ? m_ceiled : n_ceiled; + } + // Constructor Xgemm(Queue &queue, EventPointer event, const std::string &name = "GEMM"); @@ -34,7 +158,8 @@ class Xgemm: public Routine { const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, const T beta, - const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld); + const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, + const Buffer<T> &temp_buffer = Buffer<T>(0), const bool temp_buffer_provided = false); // Indirect version of GEMM (with pre and post-processing kernels) void GemmIndirect(const size_t m, const size_t n, const size_t k, @@ -45,9 +170,10 @@ class Xgemm: public Routine { const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, const bool a_conjugate, const bool b_conjugate, - const size_t a_one, const size_t a_two, const bool a_want_rotated, - const size_t b_one, const size_t b_two, const bool b_want_rotated, - const size_t c_one, const size_t c_two, const bool c_want_rotated); + const size_t a_one, const size_t a_two, + const size_t b_one, const size_t b_two, + const size_t c_one, const size_t c_two, + const Buffer<T> &temp_buffer, const bool temp_buffer_provided); // Direct version of GEMM (no pre and post-processing kernels) void GemmDirect(const size_t m, const size_t n, const size_t k, |