summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test/routines/levelx/xcol2im.hpp20
-rw-r--r--test/routines/levelx/xim2col.hpp20
2 files changed, 40 insertions, 0 deletions
diff --git a/test/routines/levelx/xcol2im.hpp b/test/routines/levelx/xcol2im.hpp
index 13e70a5f..69e8c6c1 100644
--- a/test/routines/levelx/xcol2im.hpp
+++ b/test/routines/levelx/xcol2im.hpp
@@ -197,6 +197,26 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host)
return StatusCode::kSuccess;
}
+// Half-precision version calling the above reference implementation after conversions
+template <>
+StatusCode RunReference<half>(const Arguments<half> &args, BuffersHost<half> &buffers_host) {
+ auto a_buffer2 = HalfToFloatBuffer(buffers_host.a_mat);
+ auto b_buffer2 = HalfToFloatBuffer(buffers_host.b_mat);
+ auto dummy = std::vector<float>(0);
+ auto buffers2 = BuffersHost<float>{dummy, dummy, a_buffer2, b_buffer2, dummy, dummy, dummy};
+ auto args2 = Arguments<float>();
+ args2.a_size = args.a_size; args2.b_size = args.b_size;
+ args2.channels = args.channels; args2.height = args.height; args2.width = args.width;
+ args2.kernel_h = args.kernel_h; args2.kernel_w = args.kernel_w;
+ args2.pad_h = args.pad_h; args2.pad_w = args.pad_w;
+ args2.stride_h = args.stride_h; args2.stride_w = args.stride_w;
+ args2.dilation_h = args.dilation_h; args2.dilation_w = args.dilation_w;
+ args2.a_offset = args.a_offset; args2.b_offset = args.b_offset;
+ auto status = RunReference(args2, buffers2);
+ FloatToHalfBuffer(buffers_host.a_mat, buffers2.a_mat);
+ return status;
+}
+
// =================================================================================================
} // namespace clblast
diff --git a/test/routines/levelx/xim2col.hpp b/test/routines/levelx/xim2col.hpp
index 9fd2af0c..acf7998b 100644
--- a/test/routines/levelx/xim2col.hpp
+++ b/test/routines/levelx/xim2col.hpp
@@ -188,6 +188,26 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host)
return StatusCode::kSuccess;
}
+// Half-precision version calling the above reference implementation after conversions
+template <>
+StatusCode RunReference<half>(const Arguments<half> &args, BuffersHost<half> &buffers_host) {
+ auto a_buffer2 = HalfToFloatBuffer(buffers_host.a_mat);
+ auto b_buffer2 = HalfToFloatBuffer(buffers_host.b_mat);
+ auto dummy = std::vector<float>(0);
+ auto buffers2 = BuffersHost<float>{dummy, dummy, a_buffer2, b_buffer2, dummy, dummy, dummy};
+ auto args2 = Arguments<float>();
+ args2.a_size = args.a_size; args2.b_size = args.b_size;
+ args2.channels = args.channels; args2.height = args.height; args2.width = args.width;
+ args2.kernel_h = args.kernel_h; args2.kernel_w = args.kernel_w;
+ args2.pad_h = args.pad_h; args2.pad_w = args.pad_w;
+ args2.stride_h = args.stride_h; args2.stride_w = args.stride_w;
+ args2.dilation_h = args.dilation_h; args2.dilation_w = args.dilation_w;
+ args2.a_offset = args.a_offset; args2.b_offset = args.b_offset;
+ auto status = RunReference(args2, buffers2);
+ FloatToHalfBuffer(buffers_host.b_mat, buffers2.b_mat);
+ return status;
+}
+
// =================================================================================================
} // namespace clblast