author     Cedric Nugteren <web@cedricnugteren.nl>    2018-01-08 21:07:01 +0100
committer  Cedric Nugteren <web@cedricnugteren.nl>    2018-01-08 21:07:01 +0100
commit     99a4df88a6d808ea77c9116ce63621503c00b57a (patch)
tree       d8f8bc1b3884c0340df9f6d95b4837ed3dff8deb /src/routines/levelx
parent     13f0f6fc6e612a5f77c6fd78b983f1b2bb8e36b6 (diff)
Implemented the indirect version of the strided-batched GEMM kernel
Diffstat (limited to 'src/routines/levelx')
-rw-r--r--  src/routines/levelx/xgemmstridedbatched.cpp  65
1 file changed, 22 insertions, 43 deletions
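
For reference, the routine implements C[b] = alpha * A[b] * B[b] + beta * C[b] for each batch b, where each batch's matrix starts at a fixed base offset plus the batch index times a stride (a_offset + b * a_stride, and likewise for B and C). The sketch below is an illustration of that addressing scheme only, assuming column-major storage and no transposes; ReferenceGemmStridedBatched is a hypothetical name, not part of the CLBlast API.

#include <cstddef>
#include <vector>

// Illustration: host-side reference of what the strided-batched GEMM computes.
// Batch 'b' of each matrix starts at 'offset + b * stride' (column-major assumed).
template <typename T>
void ReferenceGemmStridedBatched(const std::size_t m, const std::size_t n, const std::size_t k,
                                 const T alpha, const T beta,
                                 const std::vector<T>& a, const std::size_t a_offset,
                                 const std::size_t a_ld, const std::size_t a_stride,
                                 const std::vector<T>& b, const std::size_t b_offset,
                                 const std::size_t b_ld, const std::size_t b_stride,
                                 std::vector<T>& c, const std::size_t c_offset,
                                 const std::size_t c_ld, const std::size_t c_stride,
                                 const std::size_t batch_count) {
  for (std::size_t batch = 0; batch < batch_count; ++batch) {
    for (std::size_t j = 0; j < n; ++j) {
      for (std::size_t i = 0; i < m; ++i) {
        T acc = T(0);
        for (std::size_t l = 0; l < k; ++l) {
          acc += a[a_offset + batch * a_stride + l * a_ld + i] *  // A[b](i, l)
                 b[b_offset + batch * b_stride + j * b_ld + l];   // B[b](l, j)
        }
        auto& c_val = c[c_offset + batch * c_stride + j * c_ld + i];  // C[b](i, j)
        c_val = alpha * acc + beta * c_val;
      }
    }
  }
}
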
diff --git a/src/routines/levelx/xgemmstridedbatched.cpp b/src/routines/levelx/xgemmstridedbatched.cpp
index ddf7d878..affbceee 100644
--- a/src/routines/levelx/xgemmstridedbatched.cpp
+++ b/src/routines/levelx/xgemmstridedbatched.cpp
@@ -112,7 +112,7 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n,
const size_t b_one, const size_t b_two,
const size_t c_one, const size_t c_two,
const size_t batch_count) {
- /* TODO
+
// Calculates the ceiled versions of m, n, and k
const auto m_ceiled = Ceil(Ceil(m, db_["MWG"]), db_["VWM"]);
const auto n_ceiled = Ceil(Ceil(n, db_["NWG"]), db_["VWN"]);
@@ -124,18 +124,10 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n,
Xgemm<T>::CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"],
a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i);
- // Sets the "internal" offsets, i.e. the perfect offsets
- auto a_offsets_i = 0;//std::vector<int>(batch_count);
- auto b_offsets_i = 0;//std::vector<int>(batch_count);
- auto c_offsets_i = 0;//std::vector<int>(batch_count);
-
// Determines whether or not temporary matrices are needed
- auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && a_offsets == a_offsets_i &&
- !a_do_transpose && !a_conjugate;
- auto b_no_temp = b_one == b_one_i && b_two == b_two_i && b_ld == b_one && b_offsets == b_offsets_i &&
- !b_do_transpose && !b_conjugate;
- auto c_no_temp = c_one == c_one_i && c_two == c_two_i && c_ld == c_one && c_offsets == c_offsets_i &&
- !c_do_transpose;
+ auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && !a_do_transpose && !a_conjugate;
+ auto b_no_temp = b_one == b_one_i && b_two == b_two_i && b_ld == b_one && !b_do_transpose && !b_conjugate;
+ auto c_no_temp = c_one == c_one_i && c_two == c_two_i && c_ld == c_one && !c_do_transpose;
// Creates the temporary matrices
const auto a_temp = (a_no_temp) ? a_buffer : Buffer<T>(context_, batch_count * a_one_i * a_two_i);
@@ -150,43 +142,31 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n,
// to fill it up until it reaches a certain multiple of size (kernel parameter dependent). In
// case nothing has to be done, these kernels can be skipped.
if (!a_no_temp) {
- auto a_offsets_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
- auto a_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
- a_offsets_device.Write(queue_, batch_count, a_offsets);
- a_offsets_i_device.Write(queue_, batch_count, a_offsets_i);
auto eventProcessA = Event();
- PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessA.pointer(), emptyEventList,
- a_one, a_two, a_ld, a_offsets_device, a_buffer,
- a_one_i, a_two_i, a_one_i, a_offsets_i_device, a_temp,
- program_, true, a_do_transpose, a_conjugate, batch_count);
+ PadCopyTransposeMatrixStridedBatched(queue_, device_, db_, eventProcessA.pointer(), emptyEventList,
+ a_one, a_two, a_ld, a_offset, a_stride, a_buffer,
+ a_one_i, a_two_i, a_one_i, 0, a_one_i * a_two_i, a_temp,
+ program_, true, a_do_transpose, a_conjugate, batch_count);
eventWaitList.push_back(eventProcessA);
}
// As above, but now for matrix B
if (!b_no_temp) {
- auto b_offsets_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
- auto b_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
- b_offsets_device.Write(queue_, batch_count, b_offsets);
- b_offsets_i_device.Write(queue_, batch_count, b_offsets_i);
auto eventProcessB = Event();
- PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessB.pointer(), emptyEventList,
- b_one, b_two, b_ld, b_offsets_device, b_buffer,
- b_one_i, b_two_i, b_one_i, b_offsets_i_device, b_temp,
- program_, true, b_do_transpose, b_conjugate, batch_count);
+ PadCopyTransposeMatrixStridedBatched(queue_, device_, db_, eventProcessB.pointer(), emptyEventList,
+ b_one, b_two, b_ld, b_offset, b_stride, b_buffer,
+ b_one_i, b_two_i, b_one_i, 0, b_one_i * b_two_i, b_temp,
+ program_, true, b_do_transpose, b_conjugate, batch_count);
eventWaitList.push_back(eventProcessB);
}
// As above, but now for matrix C
- auto c_offsets_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
- auto c_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
if (!c_no_temp) {
- c_offsets_device.Write(queue_, batch_count, c_offsets);
- c_offsets_i_device.Write(queue_, batch_count, c_offsets_i);
auto eventProcessC = Event();
- PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessC.pointer(), emptyEventList,
- c_one, c_two, c_ld, c_offsets_device, c_buffer,
- c_one_i, c_two_i, c_one_i, c_offsets_i_device, c_temp,
- program_, true, c_do_transpose, false, batch_count);
+ PadCopyTransposeMatrixStridedBatched(queue_, device_, db_, eventProcessC.pointer(), emptyEventList,
+ c_one, c_two, c_ld, c_offset, c_stride, c_buffer,
+ c_one_i, c_two_i, c_one_i, 0, c_one_i * c_two_i, c_temp,
+ program_, true, c_do_transpose, false, batch_count);
eventWaitList.push_back(eventProcessC);
}
@@ -197,8 +177,8 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n,
kernel.SetArgument(0, static_cast<int>(m_ceiled));
kernel.SetArgument(1, static_cast<int>(n_ceiled));
kernel.SetArgument(2, static_cast<int>(k_ceiled));
- kernel.SetArgument(3, alpha);
- kernel.SetArgument(4, beta);
+ kernel.SetArgument(3, GetRealArg(alpha));
+ kernel.SetArgument(4, GetRealArg(beta));
kernel.SetArgument(5, a_temp());
kernel.SetArgument(6, static_cast<int>(a_one_i));
kernel.SetArgument(7, static_cast<int>(a_two_i));
@@ -225,12 +205,11 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n,
// Runs the post-processing kernel if needed
if (!c_no_temp) {
eventWaitList.push_back(eventKernel);
- PadCopyTransposeMatrixBatched(queue_, device_, db_, event_, eventWaitList,
- c_one_i, c_two_i, c_one_i, c_offsets_i_device, c_temp,
- c_one, c_two, c_ld, c_offsets_device, c_buffer,
- program_, false, c_do_transpose, false, batch_count);
+ PadCopyTransposeMatrixStridedBatched(queue_, device_, db_, event_, eventWaitList,
+ c_one_i, c_two_i, c_one_i, 0, c_one_i * c_two_i, c_temp,
+ c_one, c_two, c_ld, c_offset, c_stride, c_buffer,
+ program_, false, c_do_transpose, false, batch_count);
}
- */
}
// =================================================================================================
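
A note on the pre-processing calls in this commit: the internal (padded) temporaries are packed back-to-back in a single buffer, which is why the new PadCopyTransposeMatrixStridedBatched invocations pass a base offset of 0 and a per-batch stride of one padded matrix (e.g. a_one_i * a_two_i for A). A minimal illustration of that layout follows; InternalBatchStart is a hypothetical helper, not part of CLBlast.

#include <cstddef>

// Hypothetical helper: start index of batch 'batch' within the internal padded
// temporary for A, matching the arguments passed to the pre-processing kernel
// above (base offset 0, per-batch stride a_one_i * a_two_i).
inline std::size_t InternalBatchStart(const std::size_t a_one_i,
                                      const std::size_t a_two_i,
                                      const std::size_t batch) {
  const std::size_t stride = a_one_i * a_two_i;  // one padded matrix per batch
  return batch * stride;                         // zero base offset: dense packing
}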