From 469c346a8e0b2be9f3c736c39760548b1749918c Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Thu, 1 Nov 2018 21:44:21 +0100 Subject: Fixed half-precision tests for im2col and col2im --- test/routines/levelx/xcol2im.hpp | 20 ++++++++++++++++++++ test/routines/levelx/xim2col.hpp | 20 ++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/test/routines/levelx/xcol2im.hpp b/test/routines/levelx/xcol2im.hpp index 13e70a5f..69e8c6c1 100644 --- a/test/routines/levelx/xcol2im.hpp +++ b/test/routines/levelx/xcol2im.hpp @@ -197,6 +197,26 @@ StatusCode RunReference(const Arguments &args, BuffersHost &buffers_host) return StatusCode::kSuccess; } +// Half-precision version calling the above reference implementation after conversions +template <> +StatusCode RunReference(const Arguments &args, BuffersHost &buffers_host) { + auto a_buffer2 = HalfToFloatBuffer(buffers_host.a_mat); + auto b_buffer2 = HalfToFloatBuffer(buffers_host.b_mat); + auto dummy = std::vector(0); + auto buffers2 = BuffersHost{dummy, dummy, a_buffer2, b_buffer2, dummy, dummy, dummy}; + auto args2 = Arguments(); + args2.a_size = args.a_size; args2.b_size = args.b_size; + args2.channels = args.channels; args2.height = args.height; args2.width = args.width; + args2.kernel_h = args.kernel_h; args2.kernel_w = args.kernel_w; + args2.pad_h = args.pad_h; args2.pad_w = args.pad_w; + args2.stride_h = args.stride_h; args2.stride_w = args.stride_w; + args2.dilation_h = args.dilation_h; args2.dilation_w = args.dilation_w; + args2.a_offset = args.a_offset; args2.b_offset = args.b_offset; + auto status = RunReference(args2, buffers2); + FloatToHalfBuffer(buffers_host.a_mat, buffers2.a_mat); + return status; +} + // ================================================================================================= } // namespace clblast diff --git a/test/routines/levelx/xim2col.hpp b/test/routines/levelx/xim2col.hpp index 9fd2af0c..acf7998b 100644 --- a/test/routines/levelx/xim2col.hpp +++ b/test/routines/levelx/xim2col.hpp @@ -188,6 +188,26 @@ StatusCode RunReference(const Arguments &args, BuffersHost &buffers_host) return StatusCode::kSuccess; } +// Half-precision version calling the above reference implementation after conversions +template <> +StatusCode RunReference(const Arguments &args, BuffersHost &buffers_host) { + auto a_buffer2 = HalfToFloatBuffer(buffers_host.a_mat); + auto b_buffer2 = HalfToFloatBuffer(buffers_host.b_mat); + auto dummy = std::vector(0); + auto buffers2 = BuffersHost{dummy, dummy, a_buffer2, b_buffer2, dummy, dummy, dummy}; + auto args2 = Arguments(); + args2.a_size = args.a_size; args2.b_size = args.b_size; + args2.channels = args.channels; args2.height = args.height; args2.width = args.width; + args2.kernel_h = args.kernel_h; args2.kernel_w = args.kernel_w; + args2.pad_h = args.pad_h; args2.pad_w = args.pad_w; + args2.stride_h = args.stride_h; args2.stride_w = args.stride_w; + args2.dilation_h = args.dilation_h; args2.dilation_w = args.dilation_w; + args2.a_offset = args.a_offset; args2.b_offset = args.b_offset; + auto status = RunReference(args2, buffers2); + FloatToHalfBuffer(buffers_host.b_mat, buffers2.b_mat); + return status; +} + // ================================================================================================= } // namespace clblast -- cgit v1.2.3