summaryrefslogtreecommitdiff
path: root/test/routines
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-10-22 22:12:58 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2018-10-22 22:12:58 +0200
commit44b630fc222c6e22446c20995411994b51bc2f21 (patch)
tree870ca2a65940535551cf1b9569bf21ce2e93732c /test/routines
parentab0178c56bf989e3399a1a9738887fb59d0496ed (diff)
Some name changes in im2col code
Diffstat (limited to 'test/routines')
-rw-r--r--test/routines/levelx/xim2col.hpp26
1 files changed, 13 insertions, 13 deletions
diff --git a/test/routines/levelx/xim2col.hpp b/test/routines/levelx/xim2col.hpp
index 092e251d..9fd2af0c 100644
--- a/test/routines/levelx/xim2col.hpp
+++ b/test/routines/levelx/xim2col.hpp
@@ -39,20 +39,20 @@ public:
static std::vector<std::string> BuffersOut() { return {kBufMatB}; }
// Describes how to obtain the sizes of the buffers
- static size_t OutputHeight(const Arguments<T> &args) {
+ static size_t ColHeight(const Arguments<T> &args) {
const auto size = args.height + 2 * args.pad_h;
const auto padding = args.dilation_h * (args.kernel_h - 1) + 1;
if (size >= padding) { return (size - padding) / args.stride_h + 1; }
return 1;
}
- static size_t OutputWidth(const Arguments<T> &args) {
+ static size_t ColWidth(const Arguments<T> &args) {
const auto size = args.width + 2 * args.pad_w;
const auto padding = args.dilation_w * (args.kernel_w - 1) + 1;
if (size >= padding) { return (size - padding) / args.stride_w + 1; }
return 1;
}
static size_t NumPatches(const Arguments<T> &args) {
- return OutputHeight(args) * OutputWidth(args) * args.channels;
+ return ColHeight(args) * ColWidth(args) * args.channels;
}
static size_t GetSizeA(const Arguments<T> &args) {
return args.height * args.width * args.channels + args.a_offset;
@@ -156,13 +156,13 @@ public:
template <typename T>
StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host) {
- const auto output_h = TestXim2col<T>::OutputHeight(args);
- const auto output_w = TestXim2col<T>::OutputWidth(args);
+ const auto col_h = TestXim2col<T>::ColHeight(args);
+ const auto col_w = TestXim2col<T>::ColWidth(args);
for (auto c_id = size_t{0}; c_id < args.channels; ++c_id) { // input channels
for (auto kh_id = size_t{0}; kh_id < args.kernel_h; ++kh_id) { // kernel height
for (auto kw_id = size_t{0}; kw_id < args.kernel_w; ++kw_id) { // kernel width
- for (auto h_id = size_t{0}; h_id < output_h; ++h_id) { // image height
- for (auto w_id = size_t{0}; w_id < output_w; ++w_id) { // image width
+ for (auto h_id = size_t{0}; h_id < col_h; ++h_id) { // image height
+ for (auto w_id = size_t{0}; w_id < col_w; ++w_id) { // image width
// Retrieves the input value
const auto h_index = kh_id * args.dilation_h + args.stride_h * h_id - args.pad_h;
@@ -170,16 +170,16 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host)
auto val = ConstantZero<T>();
if (h_index >= 0 && h_index < args.height &&
w_index >= 0 && w_index < args.width) {
- const auto input_index = w_index + args.width * (h_index + args.height * c_id);
- val = buffers_host.a_mat[input_index + args.a_offset];
+ const auto im_index = w_index + args.width * (h_index + args.height * c_id);
+ val = buffers_host.a_mat[im_index + args.a_offset];
}
// Sets the output value
const auto kernel_index = kw_id + args.kernel_w * kh_id;
- const auto patch_index = w_id + output_w * h_id;
- const auto output_index = patch_index + kernel_index * output_w * output_h +
- c_id * output_w * output_h * args.kernel_h * args.kernel_w;
- buffers_host.b_mat[output_index + args.b_offset] = val;
+ const auto patch_index = w_id + col_w * h_id;
+ const auto col_index = patch_index + kernel_index * col_w * col_h +
+ c_id * col_w * col_h * args.kernel_h * args.kernel_w;
+ buffers_host.b_mat[col_index + args.b_offset] = val;
}
}
}