diff options
Diffstat (limited to 'src/clblast.cpp')
-rw-r--r-- | src/clblast.cpp | 57 |
1 files changed, 38 insertions, 19 deletions
diff --git a/src/clblast.cpp b/src/clblast.cpp index e45f504a..180693e7 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2218,79 +2218,94 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose, // Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL template <typename T> -StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Im2col(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem im_buffer, const size_t im_offset, cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event) { try { auto queue_cpp = Queue(*queue); auto routine = Xim2col<T>(queue_cpp, event); - routine.DoIm2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + routine.DoIm2col(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, Buffer<T>(im_buffer), im_offset, Buffer<T>(col_buffer), col_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Im2col<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<float>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Im2col<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<double>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Im2col<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<float2>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Im2col<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<double2>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<half>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); // Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM template <typename T> -StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Col2im(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const cl_mem col_buffer, const size_t col_offset, cl_mem im_buffer, const size_t im_offset, cl_command_queue* queue, cl_event* event) { try { auto queue_cpp = Queue(*queue); auto routine = Xcol2im<T>(queue_cpp, event); - routine.DoCol2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + routine.DoCol2im(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, Buffer<T>(col_buffer), col_offset, Buffer<T>(im_buffer), im_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Col2im<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<float>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Col2im<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<double>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Col2im<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<float2>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Col2im<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<double2>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Col2im<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<half>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM template <typename T> -StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +StatusCode Convgemm(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, const cl_mem kernel_buffer, const size_t kernel_offset, cl_mem result_buffer, const size_t result_offset, @@ -2298,24 +2313,28 @@ StatusCode Convgemm(const size_t channels, const size_t height, const size_t wid try { auto queue_cpp = Queue(*queue); auto routine = Xconvgemm<T>(queue_cpp, event); - routine.DoConvgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, + routine.DoConvgemm(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, Buffer<T>(im_buffer), im_offset, Buffer<T>(kernel_buffer), kernel_offset, Buffer<T>(result_buffer), result_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Convgemm<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm<float>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Convgemm<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm<double>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, const cl_mem, const size_t, cl_mem, const size_t, cl_command_queue*, cl_event*); -template StatusCode PUBLIC_API Convgemm<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm<half>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const cl_mem, const size_t, const cl_mem, const size_t, cl_mem, const size_t, |