summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-01-18 19:41:59 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2018-01-18 19:41:59 +0100
commitc3f9371d16a66fa28906a3be9925a646e72ea471 (patch)
tree086cf7a11e80fc024fa99fb132a6c5fc9eac1c8a
parentbc54411d193b700ef5e0e1ad08c6311597bd433d (diff)
Made GEMM routine tuning a bit more generic in preparation of possible separate batched tuning arguments
-rw-r--r--src/tuning/routines/routine_tuner.hpp72
-rw-r--r--src/tuning/routines/xgemm.cpp68
2 files changed, 93 insertions, 47 deletions
diff --git a/src/tuning/routines/routine_tuner.hpp b/src/tuning/routines/routine_tuner.hpp
index c02c1b84..2aa0b3ce 100644
--- a/src/tuning/routines/routine_tuner.hpp
+++ b/src/tuning/routines/routine_tuner.hpp
@@ -28,41 +28,44 @@ namespace clblast {
template <typename T>
void ForceSelectIndirectFrom(const size_t minimum_size, const Device &device,
- const std::string &name, const std::string& parameter_name) {
- const auto override_status = OverrideParameters(device(), name + "Routine", PrecisionValue<T>(),
+ const std::string &tuner_name, const std::string& parameter_name) {
+ const auto override_status = OverrideParameters(device(), tuner_name, PrecisionValue<T>(),
{{parameter_name, minimum_size}});
if (override_status != StatusCode::kSuccess) {
throw RuntimeError("OverrideParameters failed with status " + ToString(override_status));
}
}
-template <typename T, typename F>
-std::vector<TuningResult> TuneKernelSelection(const Device& device, const Context& context,
- Queue& queue, const Precision precision,
- F const &routine, const size_t num_runs,
- const std::string &name,
- const std::string& parameter_name) {
+// Computes the best switching point
+TuningResult GetBestResult(const std::vector<TuningResult>& scores) {
+ auto comparison = [](const TuningResult& lhs, const TuningResult& rhs) { return lhs.score < rhs.score; };
+ const auto best_configuration = std::min_element(scores.begin(), scores.end(), comparison);
+ return *best_configuration;
+}
- // Values for m, n, and k
- const auto from = size_t{64};
- const auto to = size_t{2048};
- const auto step = size_t{64};
+// Tunes at kernel-level
+template <typename T, typename F>
+void TuneKernelSelection(const Platform& platform, const Device& device, const Context& context,
+ Queue& queue, const Precision precision, F const &routine,
+ const size_t from, const size_t to, const size_t step, const size_t batch_count,
+ const size_t num_runs, const std::string &name, const std::string &tuner_name,
+ const std::string &family_name, const std::string& parameter_name) {
// Buffers
auto buffers = std::vector<Buffer<T>>{
- Buffer<T>(context, to * to),
- Buffer<T>(context, to * to),
- Buffer<T>(context, to * to)
+ Buffer<T>(context, to * to * batch_count),
+ Buffer<T>(context, to * to * batch_count),
+ Buffer<T>(context, to * to * batch_count)
};
// In-direct version
printf("\n* Testing the in-direct %s routine for m=n=k\n", name.c_str());
- ForceSelectIndirectFrom<T>(0, device, name, parameter_name);
+ ForceSelectIndirectFrom<T>(0, device, tuner_name, parameter_name);
const auto indirect = TimeRoutine(from, to, step, num_runs, queue, buffers, routine);
// Direct version
printf("\n* Testing the direct %s routine for m=n=k\n", name.c_str());
- ForceSelectIndirectFrom<T>(to + 1, device, name, parameter_name);
+ ForceSelectIndirectFrom<T>(batch_count * to + 1, device, tuner_name, parameter_name);
const auto direct = TimeRoutine(from, to, step, num_runs, queue, buffers, routine);
// Determining final score and best kernel selection point
@@ -90,29 +93,40 @@ std::vector<TuningResult> TuneKernelSelection(const Device& device, const Contex
}
// Displaying results
- printf("| || %8s indirect || %8s direct || |\n", name.c_str(), name.c_str());
- printf("| m=n=k || ms | GFLOPS || ms | GFLOPS || score | (lowest score == best switching point)\n");
- printf("x---------xx--------x----------xx--------x----------xx----------x\n");
+ printf("| || %12s indirect || %12s direct || |\n", name.c_str(), name.c_str());
+ printf("| m=n=k || ms | GFLOPS || ms | GFLOPS || score | (lowest score == best switching point)\n");
+ printf("x---------xx----------x------------xx----------x----------xx----------x\n");
for (auto i = size_t{0}; i < indirect.size(); ++i) {
assert(indirect[i].first == direct[i].first);
const auto value = indirect[i].first;
if (indirect[i].second != -1 && direct[i].second != -1) {
const auto gflops_indirect = (2 * value * value * value) / (indirect[i].second * 1.0e6);
const auto gflops_direct = (2 * value * value * value) / (direct[i].second * 1.0e6);
- printf("| %7zu || %6.2lf | %8.1lf || %6.2lf | %8.1lf || %8.3lf |\n",
+ printf("| %7zu || %8.2lf | %10.1lf || %8.2lf | %8.1lf || %8.3lf |\n",
value, indirect[i].second, gflops_indirect, direct[i].second, gflops_direct, scores[i].score);
}
}
- printf("x---------xx--------x----------xx--------x----------xx----------x\n");
+ printf("x---------xx----------x------------xx----------x----------xx----------x\n");
printf("\n");
- return scores;
-}
-// Computes the best switching point
-TuningResult GetBestResult(const std::vector<TuningResult>& scores) {
- auto comparison = [](const TuningResult& lhs, const TuningResult& rhs) { return lhs.score < rhs.score; };
- const auto best_configuration = std::min_element(scores.begin(), scores.end(), comparison);
- return *best_configuration;
+ const auto best_result = GetBestResult(scores);
+ const auto best_switching_point = best_result.config.at(parameter_name);
+ const auto best_string = parameter_name + "=" + ToString(best_switching_point);
+
+ // Outputs the results as JSON to disk, including some meta-data
+ const auto precision_string = std::to_string(static_cast<size_t>(precision));
+ auto metadata = std::vector<std::pair<std::string,std::string>>{
+ {"kernel_family", family_name},
+ {"precision", precision_string},
+ {"arg_from", ToString(from)},
+ {"arg_to", ToString(to)},
+ {"arg_step", ToString(step)},
+ {"best_kernel", best_result.name},
+ {"best_time", ToString(best_result.score)},
+ {"best_parameters", best_string}
+ };
+ PrintTimingsToFileAsJSON("clblast_" + family_name + "_" + precision_string + ".json",
+ device, platform, metadata, scores);
}
// =================================================================================================
diff --git a/src/tuning/routines/xgemm.cpp b/src/tuning/routines/xgemm.cpp
index de4bef71..0721ad7c 100644
--- a/src/tuning/routines/xgemm.cpp
+++ b/src/tuning/routines/xgemm.cpp
@@ -39,6 +39,46 @@ void RunGemmRoutine(const size_t value, const Queue& queue, const std::vector<Bu
clReleaseEvent(event);
}
+template <typename T, size_t batch_count>
+void RunGemmBatchedRoutine(const size_t value, const Queue& queue, const std::vector<Buffer<T>>& buffers) {
+ auto offsets = std::vector<size_t>(batch_count);
+ auto factors = std::vector<T>(batch_count);
+ for (auto i = size_t{0}; i < batch_count; ++i) {
+ offsets[i] = batch_count * value;
+ factors[i] = ConstantOne<T>();
+ }
+ auto queue_plain = queue();
+ auto event = cl_event{};
+ auto status = GemmBatched(Layout::kRowMajor, Transpose::kNo, Transpose::kNo,
+ value, value, value, factors.data(),
+ buffers[0](), offsets.data(), value,
+ buffers[1](), offsets.data(), value, factors.data(),
+ buffers[2](), offsets.data(), value, batch_count,
+ &queue_plain, &event);
+ if (status != StatusCode::kSuccess) {
+ throw RuntimeError("GemmBatched failed with status " + ToString(status));
+ }
+ clWaitForEvents(1, &event);
+ clReleaseEvent(event);
+}
+
+template <typename T, size_t batch_count>
+void RunGemmStridedBatchedRoutine(const size_t value, const Queue& queue, const std::vector<Buffer<T>>& buffers) {
+ auto queue_plain = queue();
+ auto event = cl_event{};
+ auto status = GemmStridedBatched(Layout::kRowMajor, Transpose::kNo, Transpose::kNo,
+ value, value, value, ConstantOne<T>(),
+ buffers[0](), 0, value, value * value,
+ buffers[1](), 0, value, value * value, ConstantOne<T>(),
+ buffers[2](), 0, value, value * value, batch_count,
+ &queue_plain, &event);
+ if (status != StatusCode::kSuccess) {
+ throw RuntimeError("Gemm failed with status " + ToString(status));
+ }
+ clWaitForEvents(1, &event);
+ clReleaseEvent(event);
+}
+
// =================================================================================================
template <typename T>
@@ -61,24 +101,16 @@ void TuneXgemm(int argc, char* argv[]) {
const auto context = Context(device);
auto queue = Queue(context, device);
- // Run the tuners for the XGEMM routine
- const auto scores = TuneKernelSelection<T>(device, context, queue, precision, RunGemmRoutine<T>,
- num_runs, "Gemm", "XGEMM_MIN_INDIRECT_SIZE");
- const auto xgemm_best = GetBestResult(scores);
- const auto xgemm_switching_point = xgemm_best.config.at("XGEMM_MIN_INDIRECT_SIZE");
- const auto xgemm_string = "XGEMM_MIN_INDIRECT_SIZE=" + ToString(xgemm_switching_point);
-
- // Outputs the results as JSON to disk, including some meta-data
- const auto precision_string = std::to_string(static_cast<size_t>(precision));
- auto metadata = std::vector<std::pair<std::string,std::string>>{
- {"kernel_family", "gemm_routine"},
- {"precision", precision_string},
- {"best_kernel", xgemm_best.name},
- {"best_time", ToString(xgemm_best.score)},
- {"best_parameters", xgemm_string}
- };
- PrintTimingsToFileAsJSON("clblast_routine_gemm_" + precision_string + ".json",
- device, platform, metadata, scores);
+ // Run the tuners for the XGEMM routines
+ TuneKernelSelection<T>(platform, device, context, queue, precision, RunGemmRoutine<T>,
+ 64, 2048, 64, 1, num_runs,
+ "gemm", "GemmRoutine", "gemm_routine", "XGEMM_MIN_INDIRECT_SIZE");
+ //TuneKernelSelection<T>(platform, device, context, queue, precision, RunGemmBatchedRoutine<T, 30>,
+ // 16, 128, 32, 30, num_runs,
+ // "gemmbatched", "GemmRoutine", "gemm_routine_2", "XGEMMBATCHED_MIN_INDIRECT_SIZE");
+ //TuneKernelSelection<T>(platform, device, context, queue, precision, RunGemmStridedBatchedRoutine<T, 30>,
+ // 16, 128, 32, 30, num_runs,
+ // "gemmstridedbatched", "GemmRoutine", "gemm_routine_3", "XGEMMSTRIDEDBATCHED_MIN_INDIRECT_SIZE");
printf("* Completed tuning process\n");
printf("\n");