From 2d1f6ba7fe842ba938490fc599b6ebd209b6560b Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Sun, 6 May 2018 11:35:34 +0200 Subject: Added convgemm skeleton, test infrastructure, and first reference implementation --- test/correctness/routines/levelx/xconvgemm.cpp | 26 ++++++++++++++++++++++++++ test/correctness/testblas.hpp | 10 ++++++++-- 2 files changed, 34 insertions(+), 2 deletions(-) create mode 100644 test/correctness/routines/levelx/xconvgemm.cpp (limited to 'test/correctness') diff --git a/test/correctness/routines/levelx/xconvgemm.cpp b/test/correctness/routines/levelx/xconvgemm.cpp new file mode 100644 index 00000000..77a0f543 --- /dev/null +++ b/test/correctness/routines/levelx/xconvgemm.cpp @@ -0,0 +1,26 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren +// +// ================================================================================================= + +#include "test/correctness/testblas.hpp" +#include "test/routines/levelx/xconvgemm.hpp" + +// Main function (not within the clblast namespace) +int main(int argc, char *argv[]) { + auto errors = size_t{0}; + errors += clblast::RunTests, float, float>(argc, argv, false, "SCONVGEMM"); + errors += clblast::RunTests, double, double>(argc, argv, true, "DCONVGEMM"); + errors += clblast::RunTests, clblast::float2, clblast::float2>(argc, argv, true, "CCONVGEMM"); + errors += clblast::RunTests, clblast::double2, clblast::double2>(argc, argv, true, "ZCONVGEMM"); + errors += clblast::RunTests, clblast::half, clblast::half>(argc, argv, true, "HCONVGEMM"); + if (errors > 0) { return 1; } else { return 0; } +} + +// ================================================================================================= diff --git a/test/correctness/testblas.hpp b/test/correctness/testblas.hpp index 54b2d6f8..1d1d2ca9 100644 --- a/test/correctness/testblas.hpp +++ b/test/correctness/testblas.hpp @@ -60,6 +60,7 @@ class TestBlas: public Tester { static const std::vector kDilationSizes; static const std::vector kKernelSizes; static const std::vector kBatchCounts; + static const std::vector kNumKernels; const std::vector kOffsets; const std::vector kAlphaValues; const std::vector kBetaValues; @@ -136,6 +137,7 @@ template const std::vector TestBlas::kBatc template const std::vector TestBlas::kPadSizes = { 0, 1 }; template const std::vector TestBlas::kDilationSizes = { 1, 2 }; template const std::vector TestBlas::kKernelSizes = { 1, 3 }; +template const std::vector TestBlas::kNumKernels = { 1, 2 }; // Test settings for the invalid tests template const std::vector TestBlas::kInvalidIncrements = { 0, 1 }; @@ -241,6 +243,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na auto dilation_hs = std::vector{args.dilation_h}; auto dilation_ws = std::vector{args.dilation_w}; auto batch_counts = std::vector{args.batch_count}; + auto num_kernelss = std::vector{args.num_kernels}; auto x_sizes = std::vector{args.x_size}; auto y_sizes = std::vector{args.y_size}; auto a_sizes = std::vector{args.a_size}; @@ -296,6 +299,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na if (option == kArgDilationH) { dilation_hs = tester.kDilationSizes; } if (option == kArgDilationW) { dilation_ws = tester.kDilationSizes; } if (option == kArgBatchCount) { batch_counts = tester.kBatchCounts; } + if (option == kArgNumKernels) { num_kernelss = tester.kNumKernels; } if (option == kArgXOffset) { x_sizes = tester.kVecSizes; } if (option == kArgYOffset) { y_sizes = tester.kVecSizes; } @@ -350,8 +354,10 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na for (auto &dilation_h: dilation_hs) { r_args.dilation_h = dilation_h; for (auto &dilation_w: dilation_ws) { r_args.dilation_w = dilation_w; for (auto &batch_count: batch_counts) { r_args.batch_count = batch_count; - C::SetSizes(r_args, tester.queue_); - regular_test_vector.push_back(r_args); + for (auto &num_kernels: num_kernelss) { r_args.num_kernels = num_kernels; + C::SetSizes(r_args, tester.queue_); + regular_test_vector.push_back(r_args); + } } } } -- cgit v1.2.3