diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-01-18 19:41:59 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2018-01-18 19:41:59 +0100 |
commit | c3f9371d16a66fa28906a3be9925a646e72ea471 (patch) | |
tree | 086cf7a11e80fc024fa99fb132a6c5fc9eac1c8a /src/tuning/routines/xgemm.cpp | |
parent | bc54411d193b700ef5e0e1ad08c6311597bd433d (diff) |
Made GEMM routine tuning a bit more generic in preparation of possible separate batched tuning arguments
Diffstat (limited to 'src/tuning/routines/xgemm.cpp')
-rw-r--r-- | src/tuning/routines/xgemm.cpp | 68 |
1 files changed, 50 insertions, 18 deletions
diff --git a/src/tuning/routines/xgemm.cpp b/src/tuning/routines/xgemm.cpp index de4bef71..0721ad7c 100644 --- a/src/tuning/routines/xgemm.cpp +++ b/src/tuning/routines/xgemm.cpp @@ -39,6 +39,46 @@ void RunGemmRoutine(const size_t value, const Queue& queue, const std::vector<Bu clReleaseEvent(event); } +template <typename T, size_t batch_count> +void RunGemmBatchedRoutine(const size_t value, const Queue& queue, const std::vector<Buffer<T>>& buffers) { + auto offsets = std::vector<size_t>(batch_count); + auto factors = std::vector<T>(batch_count); + for (auto i = size_t{0}; i < batch_count; ++i) { + offsets[i] = batch_count * value; + factors[i] = ConstantOne<T>(); + } + auto queue_plain = queue(); + auto event = cl_event{}; + auto status = GemmBatched(Layout::kRowMajor, Transpose::kNo, Transpose::kNo, + value, value, value, factors.data(), + buffers[0](), offsets.data(), value, + buffers[1](), offsets.data(), value, factors.data(), + buffers[2](), offsets.data(), value, batch_count, + &queue_plain, &event); + if (status != StatusCode::kSuccess) { + throw RuntimeError("GemmBatched failed with status " + ToString(status)); + } + clWaitForEvents(1, &event); + clReleaseEvent(event); +} + +template <typename T, size_t batch_count> +void RunGemmStridedBatchedRoutine(const size_t value, const Queue& queue, const std::vector<Buffer<T>>& buffers) { + auto queue_plain = queue(); + auto event = cl_event{}; + auto status = GemmStridedBatched(Layout::kRowMajor, Transpose::kNo, Transpose::kNo, + value, value, value, ConstantOne<T>(), + buffers[0](), 0, value, value * value, + buffers[1](), 0, value, value * value, ConstantOne<T>(), + buffers[2](), 0, value, value * value, batch_count, + &queue_plain, &event); + if (status != StatusCode::kSuccess) { + throw RuntimeError("Gemm failed with status " + ToString(status)); + } + clWaitForEvents(1, &event); + clReleaseEvent(event); +} + // ================================================================================================= template <typename T> @@ -61,24 +101,16 @@ void TuneXgemm(int argc, char* argv[]) { const auto context = Context(device); auto queue = Queue(context, device); - // Run the tuners for the XGEMM routine - const auto scores = TuneKernelSelection<T>(device, context, queue, precision, RunGemmRoutine<T>, - num_runs, "Gemm", "XGEMM_MIN_INDIRECT_SIZE"); - const auto xgemm_best = GetBestResult(scores); - const auto xgemm_switching_point = xgemm_best.config.at("XGEMM_MIN_INDIRECT_SIZE"); - const auto xgemm_string = "XGEMM_MIN_INDIRECT_SIZE=" + ToString(xgemm_switching_point); - - // Outputs the results as JSON to disk, including some meta-data - const auto precision_string = std::to_string(static_cast<size_t>(precision)); - auto metadata = std::vector<std::pair<std::string,std::string>>{ - {"kernel_family", "gemm_routine"}, - {"precision", precision_string}, - {"best_kernel", xgemm_best.name}, - {"best_time", ToString(xgemm_best.score)}, - {"best_parameters", xgemm_string} - }; - PrintTimingsToFileAsJSON("clblast_routine_gemm_" + precision_string + ".json", - device, platform, metadata, scores); + // Run the tuners for the XGEMM routines + TuneKernelSelection<T>(platform, device, context, queue, precision, RunGemmRoutine<T>, + 64, 2048, 64, 1, num_runs, + "gemm", "GemmRoutine", "gemm_routine", "XGEMM_MIN_INDIRECT_SIZE"); + //TuneKernelSelection<T>(platform, device, context, queue, precision, RunGemmBatchedRoutine<T, 30>, + // 16, 128, 32, 30, num_runs, + // "gemmbatched", "GemmRoutine", "gemm_routine_2", "XGEMMBATCHED_MIN_INDIRECT_SIZE"); + //TuneKernelSelection<T>(platform, device, context, queue, precision, RunGemmStridedBatchedRoutine<T, 30>, + // 16, 128, 32, 30, num_runs, + // "gemmstridedbatched", "GemmRoutine", "gemm_routine_3", "XGEMMSTRIDEDBATCHED_MIN_INDIRECT_SIZE"); printf("* Completed tuning process\n"); printf("\n"); |