summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt7
-rw-r--r--src/tuning/routines/xgemm.cpp136
-rw-r--r--src/utilities/timing.hpp52
3 files changed, 195 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d3b202c2..73b47637 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -377,6 +377,13 @@ if(TUNERS)
target_include_directories(clblast_tuner_${KERNEL} PUBLIC ${CLTUNE_INCLUDE_DIRS})
install(TARGETS clblast_tuner_${KERNEL} DESTINATION bin)
endforeach()
+ set(ROUTINE_TUNERS xgemm)
+ foreach(ROUTINE_TUNER ${ROUTINE_TUNERS})
+ add_executable(clblast_tuner_routine_${ROUTINE_TUNER} ${TUNERS_COMMON} src/tuning/routines/${ROUTINE_TUNER}.cpp)
+ target_link_libraries(clblast_tuner_routine_${ROUTINE_TUNER} clblast ${CLTUNE_LIBRARIES} ${API_LIBRARIES})
+ target_include_directories(clblast_tuner_routine_${ROUTINE_TUNER} PUBLIC ${CLTUNE_INCLUDE_DIRS})
+ install(TARGETS clblast_tuner_routine_${ROUTINE_TUNER} DESTINATION bin)
+ endforeach()
# Adds 'alltuners' target: runs all tuners for all precisions
set(ALLTUNERS )
diff --git a/src/tuning/routines/xgemm.cpp b/src/tuning/routines/xgemm.cpp
new file mode 100644
index 00000000..9590323a
--- /dev/null
+++ b/src/tuning/routines/xgemm.cpp
@@ -0,0 +1,136 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file tunes the Xgemm routine at a high-level: choosing between the direct (single-kernel)
+// and the in-direct (kernel plus pre/post-processing) methods.
+//
+// =================================================================================================
+
+#include <exception>
+#include <string>
+#include <vector>
+#include <assert.h>
+
+#include "utilities/utilities.hpp"
+#include "utilities/timing.hpp"
+
+namespace clblast {
+// =================================================================================================
+
+template <typename T>
+void RunGemmRoutine(const size_t value, const Queue& queue, const std::vector<Buffer<T>>& buffers) {
+ auto queue_plain = queue();
+ auto event = cl_event{};
+ auto status = Gemm(Layout::kRowMajor, Transpose::kNo, Transpose::kNo,
+ value, value, value, ConstantOne<T>(),
+ buffers[0](), 0, value,
+ buffers[1](), 0, value, ConstantOne<T>(),
+ buffers[2](), 0, value,
+ &queue_plain, &event);
+ if (status != StatusCode::kSuccess) {
+ throw RuntimeError("Gemm failed with status " + ToString(status));
+ }
+ clWaitForEvents(1, &event);
+ clReleaseEvent(event);
+}
+
+template <typename T>
+void ForceSelectIndirectFrom(const size_t minimum_size, const Device &device) {
+ const auto override_status = OverrideParameters(device(), "KernelSelection", PrecisionValue<T>(),
+ {{"XGEMM_MIN_INDIRECT_SIZE", minimum_size}});
+ if (override_status != StatusCode::kSuccess) {
+ throw RuntimeError("OverrideParameters failed with status " + ToString(override_status));
+ }
+}
+
+template <typename T>
+void TuneXgemm(int argc, char* argv[]) {
+ auto command_line_args = RetrieveCommandLineArguments(argc, argv);
+ auto help = std::string{"* Options given/available:\n"};
+ const auto platform_id = GetArgument(command_line_args, help, kArgPlatform, ConvertArgument(std::getenv("CLBLAST_PLATFORM"), size_t{0}));
+ const auto device_id = GetArgument(command_line_args, help, kArgDevice, ConvertArgument(std::getenv("CLBLAST_DEVICE"), size_t{0}));
+ const auto precision = GetArgument(command_line_args, help, kArgPrecision, Precision::kSingle);
+ const auto num_runs = GetArgument(command_line_args, help, kArgNumRuns, size_t{10});
+ fprintf(stdout, "%s\n", help.c_str());
+
+ // Values for m, n, and k
+ const auto from = size_t{64};
+ const auto to = size_t{1024};
+ const auto step = size_t{64};
+
+ // OpenCL initialisation
+ const auto platform = Platform(platform_id);
+ const auto device = Device(platform, device_id);
+ if (!PrecisionSupported<T>(device)) {
+ printf("* Unsupported precision, skipping this tuning run\n\n");
+ return;
+ }
+ const auto context = Context(device);
+ const auto queue = Queue(context, device);
+
+ // Buffers
+ auto a_mat = Buffer<T>(context, to * to);
+ auto b_mat = Buffer<T>(context, to * to);
+ auto c_mat = Buffer<T>(context, to * to);
+ auto buffers = std::vector<Buffer<T>>{a_mat, b_mat, c_mat};
+
+ // In-direct version
+ printf("[----------] Testing the in-direct GEMM routine for m=n=k\n");
+ ForceSelectIndirectFrom<T>(0, device);
+ const auto indirect = TimeRoutine(from, to, step, num_runs, queue, buffers, RunGemmRoutine<T>);
+
+ // Direct version
+ printf("[----------] Testing the direct GEMM routine for m=n=k\n");
+ ForceSelectIndirectFrom<T>(to * to * to + 1, device);
+ const auto direct = TimeRoutine(from, to, step, num_runs, queue, buffers, RunGemmRoutine<T>);
+
+ // Results
+ printf("[----------] Collecting results\n");
+ assert(indirect.size() == direct.size());
+ for (auto i = size_t{0}; i < indirect.size(); ++i) {
+ assert(indirect[i].first == direct[i].first);
+ const auto value = indirect[i].first;
+ const auto gflops_indirect = (2 * value * value * value) / (indirect[i].second * 1.0e6);
+ const auto gflops_direct = (2 * value * value * value) / (direct[i].second * 1.0e6);
+ printf("[ -------> ] %7zu %8.2lf %8.2lf\n", value, gflops_indirect, gflops_direct);
+ }
+
+ // Outputs the results as JSON to disk, including some meta-data
+ const auto precision_string = std::to_string(static_cast<size_t>(precision));
+ auto metadata = std::vector<std::pair<std::string,std::string>>{
+ {"kernel_family", "gemm_routine"},
+ {"precision", precision_string},
+ };
+ PrintTimingsToFileAsJSON("clblast_routine_gemm_" + precision_string + ".json",
+ device, platform, metadata);
+
+}
+
+// =================================================================================================
+} // namespace clblast
+
+// Shortcuts to the clblast namespace
+using half = clblast::half;
+using float2 = clblast::float2;
+using double2 = clblast::double2;
+
+// Main function (not within the clblast namespace)
+int main(int argc, char *argv[]) {
+ const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
+ switch(clblast::GetPrecision(command_line_args)) {
+ case clblast::Precision::kHalf: clblast::TuneXgemm<half>(argc, argv); break;
+ case clblast::Precision::kSingle: clblast::TuneXgemm<float>(argc, argv); break;
+ case clblast::Precision::kDouble: clblast::TuneXgemm<double>(argc, argv); break;
+ case clblast::Precision::kComplexSingle: clblast::TuneXgemm<float2>(argc, argv); break;
+ case clblast::Precision::kComplexDouble: clblast::TuneXgemm<double2>(argc, argv); break;
+ }
+ return 0;
+}
+
+// =================================================================================================
diff --git a/src/utilities/timing.hpp b/src/utilities/timing.hpp
index 3d66de2a..4622aa99 100644
--- a/src/utilities/timing.hpp
+++ b/src/utilities/timing.hpp
@@ -14,14 +14,20 @@
#ifndef CLBLAST_TIMING_H_
#define CLBLAST_TIMING_H_
+#include <cstdio>
+#include <utility>
#include <vector>
+#include <algorithm>
#include <chrono>
+#include "utilities/utilities.hpp"
+
namespace clblast {
// =================================================================================================
template <typename F>
double TimeFunction(const size_t num_runs, F const &function) {
+ function(); // warm-up
auto timings = std::vector<double>(num_runs);
for (auto &timing: timings) {
const auto start_time = std::chrono::steady_clock::now();
@@ -33,6 +39,52 @@ double TimeFunction(const size_t num_runs, F const &function) {
}
// =================================================================================================
+
+using Timing = std::pair<size_t, double>;
+
+template <typename T, typename F>
+std::vector<Timing> TimeRoutine(const size_t from, const size_t to, const size_t step,
+ const size_t num_runs, const Queue& queue,
+ const std::vector<Buffer<T>>& buffers, F const &routine) {
+ auto timings = std::vector<Timing>();
+ for (auto value = from; value < to; value += step) {
+ printf("[ RUN ] Running with value %zu\n", value);
+ try {
+ const auto FunctionToTune = [&]() { routine(value, queue, buffers); };
+ const auto time_ms = TimeFunction(num_runs, FunctionToTune);
+ printf("[ OK ] Took %.2lf ms\n", time_ms);
+ timings.push_back({value, time_ms});
+ }
+ catch (...) {
+ printf("[ ERROR ] Exception caught\n");
+ timings.push_back({value, -1.0}); // invalid
+ }
+ }
+ return timings;
+}
+
+// =================================================================================================
+
+void PrintTimingsToFileAsJSON(const std::string &filename,
+ const Device& device, const Platform& platform,
+ const std::vector<std::pair<std::string,std::string>> &descriptions) {
+ auto file = fopen(filename.c_str(), "w");
+ fprintf(file, "{\n");
+ for (auto &description: descriptions) {
+ fprintf(file, " \"%s\": \"%s\",\n", description.first.c_str(), description.second.c_str());
+ }
+ fprintf(file, " \"platform_version\": \"%s\",\n", platform.Version().c_str());
+ fprintf(file, " \"device_name\": \"%s\",\n", GetDeviceName(device).c_str());
+ fprintf(file, " \"device_vendor\": \"%s\",\n", platform.Vendor().c_str());
+ fprintf(file, " \"device_type\": \"%s\",\n", device.Type().c_str());
+ fprintf(file, " \"device_architecture\": \"%s\",\n", GetDeviceArchitecture(device).c_str());
+ fprintf(file, " \"device_core_clock\": \"%zu\",\n", device.CoreClock());
+ fprintf(file, " \"device_compute_units\": \"%zu\",\n", device.ComputeUnits());
+ fprintf(file, "}\n");
+ fclose(file);
+}
+
+// =================================================================================================
} // namespace clblast
// CLBLAST_TIMING_H_