diff options
author | Koichi Akabe <vbkaisetsu@gmail.com> | 2018-11-12 10:12:07 +0900 |
---|---|---|
committer | Koichi Akabe <vbkaisetsu@gmail.com> | 2018-11-12 10:12:07 +0900 |
commit | 032e3b0cc00a15dd2af8b4fb82d261eb7b086e26 (patch) | |
tree | cdcf4d0fc342c9ff92ee7ab3f75b0cdeced46e96 /src/clblast_cuda.cpp | |
parent | 90112618daa0d6b24ae3e53203a636d2e908dfba (diff) |
Add kernel_mode option to im2col, col2im, and convgemm functions
Diffstat (limited to 'src/clblast_cuda.cpp')
-rw-r--r-- | src/clblast_cuda.cpp | 57 |
1 files changed, 38 insertions, 19 deletions
diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp index 03d995ba..264f360d 100644 --- a/src/clblast_cuda.cpp +++ b/src/clblast_cuda.cpp @@ -2314,7 +2314,8 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose, // Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL template <typename T> -StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Im2col(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const CUdeviceptr im_buffer, const size_t im_offset, CUdeviceptr col_buffer, const size_t col_offset, const CUcontext context, const CUdevice device) { @@ -2323,36 +2324,43 @@ StatusCode Im2col(const size_t channels, const size_t height, const size_t width const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xim2col<T>(queue_cpp, nullptr); - routine.DoIm2col(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + routine.DoIm2col(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, Buffer<T>(im_buffer), im_offset, Buffer<T>(col_buffer), col_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Im2col<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<float>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Im2col<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<double>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Im2col<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<float2>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Im2col<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<double2>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Im2col<half>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); // Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM template <typename T> -StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, +StatusCode Col2im(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const CUdeviceptr col_buffer, const size_t col_offset, CUdeviceptr im_buffer, const size_t im_offset, const CUcontext context, const CUdevice device) { @@ -2361,36 +2369,43 @@ StatusCode Col2im(const size_t channels, const size_t height, const size_t width const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xcol2im<T>(queue_cpp, nullptr); - routine.DoCol2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + routine.DoCol2im(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, Buffer<T>(col_buffer), col_offset, Buffer<T>(im_buffer), im_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Col2im<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<float>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Col2im<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<double>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Col2im<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<float2>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Col2im<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<double2>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Col2im<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Col2im<half>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM template <typename T> -StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, +StatusCode Convgemm(const KernelMode kernel_mode, + const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const CUdeviceptr im_buffer, const size_t im_offset, const CUdeviceptr kernel_buffer, const size_t kernel_offset, CUdeviceptr result_buffer, const size_t result_offset, @@ -2400,24 +2415,28 @@ StatusCode Convgemm(const size_t channels, const size_t height, const size_t wid const auto device_cpp = Device(device); auto queue_cpp = Queue(context_cpp, device_cpp); auto routine = Xconvgemm<T>(queue_cpp, nullptr); - routine.DoConvgemm(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, + routine.DoConvgemm(kernel_mode, + channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count, Buffer<T>(im_buffer), im_offset, Buffer<T>(kernel_buffer), kernel_offset, Buffer<T>(result_buffer), result_offset); return StatusCode::kSuccess; } catch (...) { return DispatchException(); } } -template StatusCode PUBLIC_API Convgemm<float>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm<float>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Convgemm<double>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm<double>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, const CUcontext, const CUdevice); -template StatusCode PUBLIC_API Convgemm<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, +template StatusCode PUBLIC_API Convgemm<half>(const KernelMode, + const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const CUdeviceptr, const size_t, const CUdeviceptr, const size_t, CUdeviceptr, const size_t, |