summaryrefslogtreecommitdiff
path: root/src/tuning/routines/xgemm.cpp
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-01-18 19:41:59 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2018-01-18 19:41:59 +0100
commitc3f9371d16a66fa28906a3be9925a646e72ea471 (patch)
tree086cf7a11e80fc024fa99fb132a6c5fc9eac1c8a /src/tuning/routines/xgemm.cpp
parentbc54411d193b700ef5e0e1ad08c6311597bd433d (diff)
Made GEMM routine tuning a bit more generic in preparation of possible separate batched tuning arguments
Diffstat (limited to 'src/tuning/routines/xgemm.cpp')
-rw-r--r--src/tuning/routines/xgemm.cpp68
1 files changed, 50 insertions, 18 deletions
diff --git a/src/tuning/routines/xgemm.cpp b/src/tuning/routines/xgemm.cpp
index de4bef71..0721ad7c 100644
--- a/src/tuning/routines/xgemm.cpp
+++ b/src/tuning/routines/xgemm.cpp
@@ -39,6 +39,46 @@ void RunGemmRoutine(const size_t value, const Queue& queue, const std::vector<Bu
clReleaseEvent(event);
}
+template <typename T, size_t batch_count>
+void RunGemmBatchedRoutine(const size_t value, const Queue& queue, const std::vector<Buffer<T>>& buffers) {
+ auto offsets = std::vector<size_t>(batch_count);
+ auto factors = std::vector<T>(batch_count);
+ for (auto i = size_t{0}; i < batch_count; ++i) {
+ offsets[i] = batch_count * value;
+ factors[i] = ConstantOne<T>();
+ }
+ auto queue_plain = queue();
+ auto event = cl_event{};
+ auto status = GemmBatched(Layout::kRowMajor, Transpose::kNo, Transpose::kNo,
+ value, value, value, factors.data(),
+ buffers[0](), offsets.data(), value,
+ buffers[1](), offsets.data(), value, factors.data(),
+ buffers[2](), offsets.data(), value, batch_count,
+ &queue_plain, &event);
+ if (status != StatusCode::kSuccess) {
+ throw RuntimeError("GemmBatched failed with status " + ToString(status));
+ }
+ clWaitForEvents(1, &event);
+ clReleaseEvent(event);
+}
+
+template <typename T, size_t batch_count>
+void RunGemmStridedBatchedRoutine(const size_t value, const Queue& queue, const std::vector<Buffer<T>>& buffers) {
+ auto queue_plain = queue();
+ auto event = cl_event{};
+ auto status = GemmStridedBatched(Layout::kRowMajor, Transpose::kNo, Transpose::kNo,
+ value, value, value, ConstantOne<T>(),
+ buffers[0](), 0, value, value * value,
+ buffers[1](), 0, value, value * value, ConstantOne<T>(),
+ buffers[2](), 0, value, value * value, batch_count,
+ &queue_plain, &event);
+ if (status != StatusCode::kSuccess) {
+ throw RuntimeError("Gemm failed with status " + ToString(status));
+ }
+ clWaitForEvents(1, &event);
+ clReleaseEvent(event);
+}
+
// =================================================================================================
template <typename T>
@@ -61,24 +101,16 @@ void TuneXgemm(int argc, char* argv[]) {
const auto context = Context(device);
auto queue = Queue(context, device);
- // Run the tuners for the XGEMM routine
- const auto scores = TuneKernelSelection<T>(device, context, queue, precision, RunGemmRoutine<T>,
- num_runs, "Gemm", "XGEMM_MIN_INDIRECT_SIZE");
- const auto xgemm_best = GetBestResult(scores);
- const auto xgemm_switching_point = xgemm_best.config.at("XGEMM_MIN_INDIRECT_SIZE");
- const auto xgemm_string = "XGEMM_MIN_INDIRECT_SIZE=" + ToString(xgemm_switching_point);
-
- // Outputs the results as JSON to disk, including some meta-data
- const auto precision_string = std::to_string(static_cast<size_t>(precision));
- auto metadata = std::vector<std::pair<std::string,std::string>>{
- {"kernel_family", "gemm_routine"},
- {"precision", precision_string},
- {"best_kernel", xgemm_best.name},
- {"best_time", ToString(xgemm_best.score)},
- {"best_parameters", xgemm_string}
- };
- PrintTimingsToFileAsJSON("clblast_routine_gemm_" + precision_string + ".json",
- device, platform, metadata, scores);
+ // Run the tuners for the XGEMM routines
+ TuneKernelSelection<T>(platform, device, context, queue, precision, RunGemmRoutine<T>,
+ 64, 2048, 64, 1, num_runs,
+ "gemm", "GemmRoutine", "gemm_routine", "XGEMM_MIN_INDIRECT_SIZE");
+ //TuneKernelSelection<T>(platform, device, context, queue, precision, RunGemmBatchedRoutine<T, 30>,
+ // 16, 128, 32, 30, num_runs,
+ // "gemmbatched", "GemmRoutine", "gemm_routine_2", "XGEMMBATCHED_MIN_INDIRECT_SIZE");
+ //TuneKernelSelection<T>(platform, device, context, queue, precision, RunGemmStridedBatchedRoutine<T, 30>,
+ // 16, 128, 32, 30, num_runs,
+ // "gemmstridedbatched", "GemmRoutine", "gemm_routine_3", "XGEMMSTRIDEDBATCHED_MIN_INDIRECT_SIZE");
printf("* Completed tuning process\n");
printf("\n");