summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/kernels/levelx/xconvgemm_part2.opencl32
-rw-r--r--test/correctness/testblas.hpp2
2 files changed, 22 insertions, 12 deletions
diff --git a/src/kernels/levelx/xconvgemm_part2.opencl b/src/kernels/levelx/xconvgemm_part2.opencl
index f9b78974..46a72711 100644
--- a/src/kernels/levelx/xconvgemm_part2.opencl
+++ b/src/kernels/levelx/xconvgemm_part2.opencl
@@ -84,7 +84,6 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz
// The faster version of GEMM is not allowed on the (incomplete) borders. Therefore, this section
// processes only the main parts: output blocks of WGD by WGD.
- #if defined(CONVGEMM_WITH_IM2COL) // TEMP: To be implemented for other case as well
if ((idm < (num_patches/WGD)*WGD) && (idn < (num_kernels/WGD)*WGD)) {
// Loops over all complete workgroup tiles (K-dimension)
@@ -92,12 +91,19 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz
for (; kwg < (patch_size/WGD) * WGD; kwg += WGD) {
// Loads data: off-chip --> local (matrix A and B)
- if (num_patches % VWMD == 0 && col_offset_batch % VWMD == 0) {
- GlobalToLocalDirectA(colgm, alm, num_patches, col_offset_batch, kwg, false, false);
- }
- else {
- GlobalToLocalScalarA(colgms, alm, num_patches, col_offset_batch, kwg, false, false);
- }
+ #if defined(CONVGEMM_WITH_IM2COL)
+ if (num_patches % VWMD == 0 && col_offset_batch % VWMD == 0) {
+ GlobalToLocalDirectA(colgm, alm, num_patches, col_offset_batch, kwg, false, false);
+ }
+ else {
+ GlobalToLocalScalarA(colgms, alm, num_patches, col_offset_batch, kwg, false, false);
+ }
+ #else
+ GlobalToLocalCheckedImage(imagegm, alm, image_offset_batch, h_id, w_id, kwg,
+ input_h, input_w, channels, kernel_h, kernel_w,
+ pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w);
+ #endif
if (patch_size % VWND == 0 && kernel_offset % VWND == 0) {
GlobalToLocalDirectB(kernelgm, blm, patch_size, kernel_offset, kwg, true, false);
}
@@ -141,7 +147,14 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz
// Loads data: off-chip --> private (matrix A and B)
#pragma unroll
for (int _mi = 0; _mi < MWID; _mi += 1) {
- apd[_mi] = GlobalToPrivateDirectA(colgms, _mi, num_patches, col_offset_batch, idm, kwg, false, false);
+ #if defined(CONVGEMM_WITH_IM2COL)
+ apd[_mi] = GlobalToPrivateDirectA(colgms, _mi, num_patches, col_offset_batch, idm, kwg, false, false);
+ #else
+ apd[_mi] = GlobalToPrivateCheckedImage(imagegm, image_offset_batch, h_id, w_id, kwg,
+ input_h, input_w, channels, kernel_h, kernel_w,
+ pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w);
+ #endif
}
#pragma unroll
for (int _ni = 0; _ni < NWID; _ni += 1) {
@@ -171,9 +184,6 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz
// Simple but slower version for the parts on the edge (incomplete tiles in M and N-dimensions)
else {
- #else // TEMP, to be implemented
- { // TEMP, to be implemented
- #endif // TEMP, to be implemented
// Loops over all complete workgroup tiles (K-dimension)
int kwg = 0;
for (; kwg < (patch_size/WGD) * WGD; kwg+=WGD) {
diff --git a/test/correctness/testblas.hpp b/test/correctness/testblas.hpp
index 0dc8584e..6c0abab6 100644
--- a/test/correctness/testblas.hpp
+++ b/test/correctness/testblas.hpp
@@ -139,7 +139,7 @@ template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kBatc
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kPadSizes = { 0, 1 };
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kDilationSizes = { 1, 2 };
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kKernelSizes = { 1, 3 };
-template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kNumKernels = { 1, 2 };
+template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kNumKernels = { 1, 67 };
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kStrideValues = { 1, 3 };
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kChannelValues = { 1, 4 };