summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-10-07 18:41:46 +0200
committerGitHub <noreply@github.com>2017-10-07 18:41:46 +0200
commitb2058320d12cbb3a250f62e9c85ff8c840659ee7 (patch)
treefd406dd26d17d5282e9864a3a8c7634b25e1cbfa
parent1009303717d1722fd01ef43d1a08ee9d0899ad41 (diff)
parent86b80cdc98a36e75023403ba6c54df9f5b94854f (diff)
Merge pull request #197 from CNugteren/single_temporary_gemm_buffer
Single temporary GEMM buffer
-rw-r--r--CHANGELOG2
-rw-r--r--src/clblast.cpp9
-rw-r--r--src/kernels/level3/xgemm_part3.opencl9
-rw-r--r--src/routines/level3/xgemm.cpp30
-rw-r--r--src/tuning/kernels/xgemm.cpp2
-rw-r--r--test/correctness/misc/override_parameters.cpp2
-rw-r--r--test/correctness/routines/level3/xgemm.cpp20
-rw-r--r--test/performance/routines/level3/xgemm.cpp10
-rw-r--r--test/routines/level3/xgemm.hpp9
9 files changed, 70 insertions, 23 deletions
diff --git a/CHANGELOG b/CHANGELOG
index 56dad68e..bb2013a6 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -1,6 +1,8 @@
Development (next version)
- Kernels are now cached based on their tuning parameters: fits the use-case of 'OverrideParameters'
+- Improved performance for small GEMM problems by going from 3 to 1 optional temporary buffers
+- Various minor fixes and enhancements
Version 1.1.0
- The tuning database now has defaults per architecture (e.g. NVIDIA Kepler SM3.5, AMD Fiji)
diff --git a/src/clblast.cpp b/src/clblast.cpp
index bb338503..19d7ef0a 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2497,8 +2497,13 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern
// Retrieves the current database values to verify whether the new ones are complete
auto in_cache = false;
- const auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device, precision, kernel_name}, &in_cache);
- if (!in_cache) { return StatusCode::kInvalidOverrideKernel; }
+ auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device, precision, kernel_name}, &in_cache);
+ if (!in_cache) {
+ log_debug("Searching database for kernel '" + kernel_name + "'");
+ current_database = Database(device_cpp, kernel_name, precision, {});
+ }
+
+ // Verifies the parameters size
const auto current_parameter_names = current_database.GetParameterNames();
if (current_parameter_names.size() != parameters.size()) {
return StatusCode::kMissingOverrideParameter;
diff --git a/src/kernels/level3/xgemm_part3.opencl b/src/kernels/level3/xgemm_part3.opencl
index 3f0d590d..f447677f 100644
--- a/src/kernels/level3/xgemm_part3.opencl
+++ b/src/kernels/level3/xgemm_part3.opencl
@@ -17,7 +17,7 @@ R"(
// =================================================================================================
-// Main body of the matrix-multiplication algorithm. It calls the (inlined) functions above.
+// Main body of the matrix-multiplication algorithm. It calls various (inlined) functions.
INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
const __global realM* restrict agm, const __global realN* restrict bgm,
__global realM* cgm, realM cpm[NWI][MWI/VWM]
@@ -192,10 +192,15 @@ void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK,
const real_arg arg_beta,
const __global realM* restrict agm,
const __global realN* restrict bgm,
- __global realM* cgm) {
+ __global realM* cgm,
+ const int b_offset, const int c_offset) {
const real alpha = GetRealArg(arg_alpha);
const real beta = GetRealArg(arg_beta);
+ // Adds the offsets (in case of use of a single temporary buffer for A, B, and C)
+ bgm = &bgm[b_offset];
+ cgm = &cgm[c_offset];
+
// Allocates workgroup-private memory (local memory)
#if SA == 1
__local realM alm[KWG * MWG/VWM];
diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp
index 3909c308..a0063ee2 100644
--- a/src/routines/level3/xgemm.cpp
+++ b/src/routines/level3/xgemm.cpp
@@ -161,10 +161,24 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
auto c_no_temp = c_one == c_one_i && c_two == c_two_i && c_ld == c_one && c_offset == 0 &&
c_do_transpose == false;
- // Creates the temporary matrices
- const auto a_temp = (a_no_temp) ? a_buffer : Buffer<T>(context_, a_one_i*a_two_i);
- const auto b_temp = (b_no_temp) ? b_buffer : Buffer<T>(context_, b_one_i*b_two_i);
- const auto c_temp = (c_no_temp) ? c_buffer : Buffer<T>(context_, c_one_i*c_two_i);
+ // Computes the sizes and offsets for (optional) temporary buffers for the 3 matrices
+ auto temp_size = size_t{0};
+ auto b_temp_offset = size_t{0};
+ auto c_temp_offset = size_t{0};
+ if (!a_no_temp) { temp_size += a_one_i*a_two_i; }
+ if (!b_no_temp) { b_temp_offset = temp_size; temp_size += b_one_i*b_two_i; }
+ if (!c_no_temp) { c_temp_offset = temp_size; temp_size += c_one_i*c_two_i; }
+ if (!IsMultiple(b_temp_offset, db_["VWN"])) { throw BLASError(StatusCode::kUnexpectedError); }
+ if (!IsMultiple(c_temp_offset, db_["VWM"])) { throw BLASError(StatusCode::kUnexpectedError); }
+
+ // Creates the buffer for the (optional) temporary matrices. Note that we use 'a_buffer' in case
+ // when no temporary buffer is needed, but that's just to make it compile: it is never used.
+ const auto temp_buffer = (temp_size > 0) ? Buffer<T>(context_, temp_size) : a_buffer;
+
+ // Sets the buffer pointers for (temp) matrices A, B, and C
+ const auto a_temp = (a_no_temp) ? a_buffer : temp_buffer;
+ const auto b_temp = (b_no_temp) ? b_buffer : temp_buffer;
+ const auto c_temp = (c_no_temp) ? c_buffer : temp_buffer;
// Events of all kernels (including pre/post processing kernels)
auto eventWaitList = std::vector<Event>();
@@ -188,7 +202,7 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
auto eventProcessB = Event();
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessB.pointer(), emptyEventList,
b_one, b_two, b_ld, b_offset, b_buffer,
- b_one_i, b_two_i, b_one_i, 0, b_temp,
+ b_one_i, b_two_i, b_one_i, b_temp_offset, b_temp,
ConstantOne<T>(), program_,
true, b_do_transpose, b_conjugate);
eventWaitList.push_back(eventProcessB);
@@ -199,7 +213,7 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
auto eventProcessC = Event();
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessC.pointer(), emptyEventList,
c_one, c_two, c_ld, c_offset, c_buffer,
- c_one_i, c_two_i, c_one_i, 0, c_temp,
+ c_one_i, c_two_i, c_one_i, c_temp_offset, c_temp,
ConstantOne<T>(), program_,
true, c_do_transpose, false);
eventWaitList.push_back(eventProcessC);
@@ -217,6 +231,8 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
kernel.SetArgument(5, a_temp());
kernel.SetArgument(6, b_temp());
kernel.SetArgument(7, c_temp());
+ kernel.SetArgument(8, static_cast<int>(b_temp_offset / db_["VWN"]));
+ kernel.SetArgument(9, static_cast<int>(c_temp_offset / db_["VWM"]));
// Computes the global and local thread sizes
const auto global = std::vector<size_t>{
@@ -234,7 +250,7 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
if (!c_no_temp) {
eventWaitList.push_back(eventKernel);
PadCopyTransposeMatrix(queue_, device_, db_, event_, eventWaitList,
- c_one_i, c_two_i, c_one_i, 0, c_temp,
+ c_one_i, c_two_i, c_one_i, c_temp_offset, c_temp,
c_one, c_two, c_ld, c_offset, c_buffer,
ConstantOne<T>(), program_,
false, c_do_transpose, false);
diff --git a/src/tuning/kernels/xgemm.cpp b/src/tuning/kernels/xgemm.cpp
index 7d0f3ed4..6dcdf68b 100644
--- a/src/tuning/kernels/xgemm.cpp
+++ b/src/tuning/kernels/xgemm.cpp
@@ -180,6 +180,8 @@ class TuneXgemm {
tuner.AddArgumentInput(a_mat);
tuner.AddArgumentInput(b_mat);
tuner.AddArgumentOutput(c_mat);
+ tuner.AddArgumentScalar(0);
+ tuner.AddArgumentScalar(0);
}
};
diff --git a/test/correctness/misc/override_parameters.cpp b/test/correctness/misc/override_parameters.cpp
index 95ece98c..05f40f57 100644
--- a/test/correctness/misc/override_parameters.cpp
+++ b/test/correctness/misc/override_parameters.cpp
@@ -28,7 +28,7 @@ size_t RunOverrideTests(int argc, char *argv[], const bool silent, const std::st
auto arguments = RetrieveCommandLineArguments(argc, argv);
auto errors = size_t{0};
auto passed = size_t{0};
- auto example_routine = TestXgemm<T>();
+ auto example_routine = TestXgemm<0, T>();
constexpr auto kSeed = 42; // fixed seed for reproducibility
// Determines the test settings
diff --git a/test/correctness/routines/level3/xgemm.cpp b/test/correctness/routines/level3/xgemm.cpp
index 5de73554..bdf57b36 100644
--- a/test/correctness/routines/level3/xgemm.cpp
+++ b/test/correctness/routines/level3/xgemm.cpp
@@ -15,11 +15,21 @@
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
auto errors = size_t{0};
- errors += clblast::RunTests<clblast::TestXgemm<float>, float, float>(argc, argv, false, "SGEMM");
- errors += clblast::RunTests<clblast::TestXgemm<double>, double, double>(argc, argv, true, "DGEMM");
- errors += clblast::RunTests<clblast::TestXgemm<clblast::float2>, clblast::float2, clblast::float2>(argc, argv, true, "CGEMM");
- errors += clblast::RunTests<clblast::TestXgemm<clblast::double2>, clblast::double2, clblast::double2>(argc, argv, true, "ZGEMM");
- errors += clblast::RunTests<clblast::TestXgemm<clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HGEMM");
+
+ // Tests GEMM based on the 'in-direct' kernel
+ errors += clblast::RunTests<clblast::TestXgemm<1, float>, float, float>(argc, argv, false, "SGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<1, double>, double, double>(argc, argv, true, "DGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<1, clblast::float2>, clblast::float2, clblast::float2>(argc, argv, true, "CGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<1, clblast::double2>, clblast::double2, clblast::double2>(argc, argv, true, "ZGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<1, clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HGEMM");
+
+ // Tests GEMM based on the 'direct' kernel
+ errors += clblast::RunTests<clblast::TestXgemm<2, float>, float, float>(argc, argv, true, "SGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<2, double>, double, double>(argc, argv, true, "DGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<2, clblast::float2>, clblast::float2, clblast::float2>(argc, argv, true, "CGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<2, clblast::double2>, clblast::double2, clblast::double2>(argc, argv, true, "ZGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<2, clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HGEMM");
+
if (errors > 0) { return 1; } else { return 0; }
}
diff --git a/test/performance/routines/level3/xgemm.cpp b/test/performance/routines/level3/xgemm.cpp
index 5b3426f5..0b67b4d3 100644
--- a/test/performance/routines/level3/xgemm.cpp
+++ b/test/performance/routines/level3/xgemm.cpp
@@ -17,15 +17,15 @@ int main(int argc, char *argv[]) {
const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
switch(clblast::GetPrecision(command_line_args, clblast::Precision::kSingle)) {
case clblast::Precision::kHalf:
- clblast::RunClient<clblast::TestXgemm<clblast::half>, clblast::half, clblast::half>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemm<0, clblast::half>, clblast::half, clblast::half>(argc, argv); break;
case clblast::Precision::kSingle:
- clblast::RunClient<clblast::TestXgemm<float>, float, float>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemm<0, float>, float, float>(argc, argv); break;
case clblast::Precision::kDouble:
- clblast::RunClient<clblast::TestXgemm<double>, double, double>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemm<0, double>, double, double>(argc, argv); break;
case clblast::Precision::kComplexSingle:
- clblast::RunClient<clblast::TestXgemm<clblast::float2>, clblast::float2, clblast::float2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemm<0, clblast::float2>, clblast::float2, clblast::float2>(argc, argv); break;
case clblast::Precision::kComplexDouble:
- clblast::RunClient<clblast::TestXgemm<clblast::double2>, clblast::double2, clblast::double2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemm<0, clblast::double2>, clblast::double2, clblast::double2>(argc, argv); break;
}
return 0;
}
diff --git a/test/routines/level3/xgemm.hpp b/test/routines/level3/xgemm.hpp
index 7e0ead6d..1c430c1c 100644
--- a/test/routines/level3/xgemm.hpp
+++ b/test/routines/level3/xgemm.hpp
@@ -22,7 +22,7 @@ namespace clblast {
// =================================================================================================
// See comment at top of file for a description of the class
-template <typename T>
+template <int V, typename T> // 'V' is the version of the kernel (0 for default, 1 for 'in-direct', 2 for 'direct')
class TestXgemm {
public:
@@ -83,6 +83,13 @@ class TestXgemm {
// Describes how to run the CLBlast routine
static StatusCode RunRoutine(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
+ if (V != 0) {
+ const auto device = queue.GetDevice();
+ const auto switch_threshold = (V == 1) ? size_t{0} : size_t{1024 * 1024 * 1024}; // large enough for tests
+ const auto override_status = OverrideParameters(device(), "KernelSelection", PrecisionValue<T>(),
+ {{"XGEMM_MIN_INDIRECT_SIZE", switch_threshold}});
+ if (override_status != StatusCode::kSuccess) { return override_status; }
+ }
auto queue_plain = queue();
auto event = cl_event{};
auto status = Gemm(args.layout, args.a_transpose, args.b_transpose,