summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
author Cedric Nugteren <web@cedricnugteren.nl> 2018-12-31 13:19:58 +0100
committer Cedric Nugteren <web@cedricnugteren.nl> 2018-12-31 13:19:58 +0100
commit 153ac06cf262d2680d0152933156b1d1e15b3f86 (patch)
tree 99e52fa883c8da5e19f96a3a1ed33bc33ad5f7d3 /src
parent b894993967529f8b878e376d9dfe146e7fee26aa (diff)
Added the forgotten batch dimension to the tuner to get correct kernel executions
Diffstat (limited to 'src')
-rw-r--r-- src/tuning/kernels/xconvgemm.hpp 12
1 files changed, 6 insertions, 6 deletions
diff --git a/src/tuning/kernels/xconvgemm.hpp b/src/tuning/kernels/xconvgemm.hpp
index 9ba70f5e..10dc8ba6 100644
--- a/src/tuning/kernels/xconvgemm.hpp
+++ b/src/tuning/kernels/xconvgemm.hpp
@@ -86,10 +86,10 @@ TunerSettings XConvGemmGetTunerSettings(const int, const Arguments<T> &args) {
settings.outputs = {4};
// Sets the base thread configuration
- settings.global_size = {num_patches, args.num_kernels};
+ settings.global_size = {num_patches, args.num_kernels, args.batch_count};
settings.global_size_ref = settings.global_size;
- settings.local_size = {1, 1};
- settings.local_size_ref = {8, 8};
+ settings.local_size = {1, 1, 1};
+ settings.local_size_ref = {8, 8, 1};
// Transforms the thread configuration based on the parameters
settings.mul_local = {{"MDIMCD", "NDIMCD"}};
@@ -161,12 +161,12 @@ void XConvGemmSetArguments(const int, Kernel &kernel, const Arguments<T> &args,
kernel.SetArgument(1, static_cast<int>(args.num_kernels));
kernel.SetArgument(2, static_cast<int>(patch_size));
kernel.SetArgument(3, buffers[3]()); // 3 == B matrix ==> kernel buffer
- kernel.SetArgument(4, 0); // c_offset
+ kernel.SetArgument(4, 0); // kernel offset
kernel.SetArgument(5, buffers[4]()); // 4 == C matrix ==> result buffer
- kernel.SetArgument(6, 0); // c_offset
+ kernel.SetArgument(6, 0); // result offset
kernel.SetArgument(7, static_cast<int>(result_stride));
kernel.SetArgument(8, buffers[2]()); // 2 == A matrix ==> image buffer
- kernel.SetArgument(9, 0); // c_offset
+ kernel.SetArgument(9, 0); // image offset
kernel.SetArgument(10, static_cast<int>(args.height));
kernel.SetArgument(11, static_cast<int>(args.width));
kernel.SetArgument(12, static_cast<int>(args.channels));