diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-09-07 22:04:08 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2018-09-07 22:04:08 +0200 |
commit | bbb4523b7cc664ad64cc17f7381e6bbfb0874f06 (patch) | |
tree | 2a7432fd718400cae3563d5ffb9f4f1b594c5b96 /test | |
parent | c788e040f7f4e46d9f03644cadb65788fe42571e (diff) |
Added reference implementation for xCONVGEMM for half-precision
Diffstat (limited to 'test')
-rw-r--r-- | test/routines/levelx/xconvgemm.hpp | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/test/routines/levelx/xconvgemm.hpp b/test/routines/levelx/xconvgemm.hpp index 7233f7b6..7fa4e701 100644 --- a/test/routines/levelx/xconvgemm.hpp +++ b/test/routines/levelx/xconvgemm.hpp @@ -214,6 +214,28 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host) return StatusCode::kSuccess; } +// Half-precision version calling the above reference implementation after conversions +template <> +StatusCode RunReference<half>(const Arguments<half> &args, BuffersHost<half> &buffers_host) { + auto a_buffer2 = HalfToFloatBuffer(buffers_host.a_mat); + auto b_buffer2 = HalfToFloatBuffer(buffers_host.b_mat); + auto c_buffer2 = HalfToFloatBuffer(buffers_host.c_mat); + auto dummy = std::vector<float>(0); + auto buffers2 = BuffersHost<float>{dummy, dummy, a_buffer2, b_buffer2, c_buffer2, dummy, dummy}; + auto args2 = Arguments<float>(); + args2.a_size = args.a_size; args2.b_size = args.b_size; args2.c_size = args.c_size; + args2.channels = args.channels; args2.height = args.height; args2.width = args.width; + args2.kernel_h = args.kernel_h; args2.kernel_w = args.kernel_w; + args2.pad_h = args.pad_h; args2.pad_w = args.pad_w; + args2.stride_h = args.stride_h; args2.stride_w = args.stride_w; + args2.dilation_h = args.dilation_h; args2.dilation_w = args.dilation_w; + args2.num_kernels = args.num_kernels; args2.batch_count = args.batch_count; + args2.a_offset = args.a_offset; args2.b_offset = args.b_offset; args2.c_offset = args.c_offset; + auto status = RunReference(args2, buffers2); + FloatToHalfBuffer(buffers_host.c_mat, buffers2.c_mat); + return status; +} + // ================================================================================================= } // namespace clblast |