summaryrefslogtreecommitdiff
path: root/src/clblast_c.cpp
diff options
context:
space:
mode:
authorKoichi Akabe <vbkaisetsu@gmail.com>2018-11-12 10:12:07 +0900
committerKoichi Akabe <vbkaisetsu@gmail.com>2018-11-12 10:12:07 +0900
commit032e3b0cc00a15dd2af8b4fb82d261eb7b086e26 (patch)
treecdcf4d0fc342c9ff92ee7ab3f75b0cdeced46e96 /src/clblast_c.cpp
parent90112618daa0d6b24ae3e53203a636d2e908dfba (diff)
Add kernel_mode option to im2col, col2im, and convgemm functions
Diffstat (limited to 'src/clblast_c.cpp')
-rw-r--r--src/clblast_c.cpp78
1 files changed, 52 insertions, 26 deletions
diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp
index 645a69b1..a224230a 100644
--- a/src/clblast_c.cpp
+++ b/src/clblast_c.cpp
@@ -3613,65 +3613,75 @@ CLBlastStatusCode CLBlastHomatcopy(const CLBlastLayout layout, const CLBlastTran
}
// IM2COL
-CLBlastStatusCode CLBlastSim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastSim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Im2col<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Im2col<float>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_offset,
col_buffer, col_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastDim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastDim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Im2col<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Im2col<double>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_offset,
col_buffer, col_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastCim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastCim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Im2col<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Im2col<float2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_offset,
col_buffer, col_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastZim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastZim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Im2col<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Im2col<double2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_offset,
col_buffer, col_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastHim2col(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem im_buffer, const size_t im_offset,
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Im2col<half>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Im2col<half>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
im_buffer, im_offset,
col_buffer, col_offset,
queue, event)
@@ -3680,65 +3690,75 @@ CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, con
}
// COL2IM
-CLBlastStatusCode CLBlastScol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastScol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Col2im<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Col2im<float>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer, col_offset,
im_buffer, im_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastDcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastDcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Col2im<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Col2im<double>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer, col_offset,
im_buffer, im_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastCcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastCcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Col2im<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Col2im<float2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer, col_offset,
im_buffer, im_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastZcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastZcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Col2im<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Col2im<double2>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer, col_offset,
im_buffer, im_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
+CLBlastStatusCode CLBlastHcol2im(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const cl_mem col_buffer, const size_t col_offset,
cl_mem im_buffer, const size_t im_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Col2im<half>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ clblast::Col2im<half>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
col_buffer, col_offset,
im_buffer, im_offset,
queue, event)
@@ -3747,14 +3767,16 @@ CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, con
}
// CONVGEMM
-CLBlastStatusCode CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+CLBlastStatusCode CLBlastSconvgemm(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Convgemm<float>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+ clblast::Convgemm<float>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
im_buffer, im_offset,
kernel_buffer, kernel_offset,
result_buffer, result_offset,
@@ -3762,14 +3784,16 @@ CLBlastStatusCode CLBlastSconvgemm(const size_t channels, const size_t height, c
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+CLBlastStatusCode CLBlastDconvgemm(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Convgemm<double>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+ clblast::Convgemm<double>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
im_buffer, im_offset,
kernel_buffer, kernel_offset,
result_buffer, result_offset,
@@ -3777,14 +3801,16 @@ CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, c
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
-CLBlastStatusCode CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
+CLBlastStatusCode CLBlastHconvgemm(const CLBlastKernelMode kernel_mode,
+ const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
- clblast::Convgemm<half>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
+ clblast::Convgemm<half>(static_cast<clblast::KernelMode>(kernel_mode),
+ channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
im_buffer, im_offset,
kernel_buffer, kernel_offset,
result_buffer, result_offset,