From df7638c3058c59b173f04cadef78c1955ac008f6 Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Sun, 26 Feb 2017 14:31:05 +0100 Subject: Fixed an out-of-bounds memory access when filling a matrix with a constant --- src/kernels/level3/level3.opencl | 4 ++-- src/routines/common.hpp | 15 ++++++++------- src/routines/level3/xtrsm.cpp | 3 ++- src/routines/levelx/xinvert.cpp | 2 +- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/kernels/level3/level3.opencl b/src/kernels/level3/level3.opencl index 0f5a8607..5ba8cf29 100644 --- a/src/kernels/level3/level3.opencl +++ b/src/kernels/level3/level3.opencl @@ -77,12 +77,12 @@ R"( #if defined(ROUTINE_INVERT) || defined(ROUTINE_TRSM) __kernel __attribute__((reqd_work_group_size(8, 8, 1))) -void FillMatrix(const int n, const int ld, const int offset, +void FillMatrix(const int m, const int n, const int ld, const int offset, __global real* restrict dest, const real_arg arg_value) { const real value = GetRealArg(arg_value); const int id_one = get_global_id(0); const int id_two = get_global_id(1); - if (id_one < ld && id_two < n) { + if (id_one < m && id_two < n) { dest[id_two*ld + id_one + offset] = value; } } diff --git a/src/routines/common.hpp b/src/routines/common.hpp index bdea0086..47d62027 100644 --- a/src/routines/common.hpp +++ b/src/routines/common.hpp @@ -38,17 +38,18 @@ template void FillMatrix(Queue &queue, const Device &device, const Program &program, const Database &, EventPointer event, const std::vector &waitForEvents, - const size_t n, const size_t ld, const size_t offset, + const size_t m, const size_t n, const size_t ld, const size_t offset, const Buffer &dest, const T constant_value) { auto kernel = Kernel(program, "FillMatrix"); - kernel.SetArgument(0, static_cast(n)); - kernel.SetArgument(1, static_cast(ld)); - kernel.SetArgument(2, static_cast(offset)); - kernel.SetArgument(3, dest()); - kernel.SetArgument(4, GetRealArg(constant_value)); + kernel.SetArgument(0, static_cast(m)); + kernel.SetArgument(1, static_cast(n)); + kernel.SetArgument(2, static_cast(ld)); + kernel.SetArgument(3, static_cast(offset)); + kernel.SetArgument(4, dest()); + kernel.SetArgument(5, GetRealArg(constant_value)); auto local = std::vector{8, 8}; - auto global = std::vector{Ceil(ld, 8), Ceil(n, 8)}; + auto global = std::vector{Ceil(m, 8), Ceil(n, 8)}; RunKernel(kernel, queue, device, global, local, event, waitForEvents); } diff --git a/src/routines/level3/xtrsm.cpp b/src/routines/level3/xtrsm.cpp index 42855362..b734dd2d 100644 --- a/src/routines/level3/xtrsm.cpp +++ b/src/routines/level3/xtrsm.cpp @@ -91,6 +91,7 @@ void Xtrsm::TrsmColMajor(const Side side, const Triangle triangle, // Creates a copy of B to avoid overwriting input in GEMM while computing output const auto b_size = b_ld * (n - 1) + m + b_offset; const auto x_one = m; + const auto x_two = n; const auto x_size = b_size; const auto x_ld = b_ld; const auto x_offset = b_offset; @@ -105,7 +106,7 @@ void Xtrsm::TrsmColMajor(const Side side, const Triangle triangle, auto eventWaitList = std::vector(); auto fill_matrix_event = Event(); FillMatrix(queue_, device_, program_, db_, fill_matrix_event.pointer(), eventWaitList, - x_one, x_ld, x_offset, x_buffer, ConstantZero()); + x_one, x_two, x_ld, x_offset, x_buffer, ConstantZero()); fill_matrix_event.WaitForCompletion(); // Inverts the diagonal blocks diff --git a/src/routines/levelx/xinvert.cpp b/src/routines/levelx/xinvert.cpp index bcc3706d..5c21d5ce 100644 --- a/src/routines/levelx/xinvert.cpp +++ b/src/routines/levelx/xinvert.cpp @@ -73,7 +73,7 @@ void Xinvert::InvertMatrixDiagonalBlocks(const Layout layout, const Triangle auto event_wait_list = std::vector(); auto fill_matrix_event = Event(); FillMatrix(queue_, device_, program_, db_, fill_matrix_event.pointer(), event_wait_list, - num_blocks * block_size, block_size, 0, dest, ConstantZero()); + block_size, num_blocks * block_size, block_size, 0, dest, ConstantZero()); event_wait_list.push_back(fill_matrix_event); // Inverts the diagonal IB by IB inner blocks of the matrix: one block per work-group -- cgit v1.2.3