diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-05-06 11:35:34 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2018-05-06 11:35:34 +0200 |
commit | 2d1f6ba7fe842ba938490fc599b6ebd209b6560b (patch) | |
tree | f1a284e5dc0163b7fed938a3efeb39432b9d3788 /test/correctness | |
parent | 2776d761768295b01a8be7c333dbb337805d7f77 (diff) |
Added convgemm skeleton, test infrastructure, and first reference implementation
Diffstat (limited to 'test/correctness')
-rw-r--r-- | test/correctness/routines/levelx/xconvgemm.cpp | 26 | ||||
-rw-r--r-- | test/correctness/testblas.hpp | 10 |
2 files changed, 34 insertions, 2 deletions
diff --git a/test/correctness/routines/levelx/xconvgemm.cpp b/test/correctness/routines/levelx/xconvgemm.cpp new file mode 100644 index 00000000..77a0f543 --- /dev/null +++ b/test/correctness/routines/levelx/xconvgemm.cpp @@ -0,0 +1,26 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren <www.cedricnugteren.nl> +// +// ================================================================================================= + +#include "test/correctness/testblas.hpp" +#include "test/routines/levelx/xconvgemm.hpp" + +// Main function (not within the clblast namespace) +int main(int argc, char *argv[]) { + auto errors = size_t{0}; + errors += clblast::RunTests<clblast::TestXconvgemm<float>, float, float>(argc, argv, false, "SCONVGEMM"); + errors += clblast::RunTests<clblast::TestXconvgemm<double>, double, double>(argc, argv, true, "DCONVGEMM"); + errors += clblast::RunTests<clblast::TestXconvgemm<clblast::float2>, clblast::float2, clblast::float2>(argc, argv, true, "CCONVGEMM"); + errors += clblast::RunTests<clblast::TestXconvgemm<clblast::double2>, clblast::double2, clblast::double2>(argc, argv, true, "ZCONVGEMM"); + errors += clblast::RunTests<clblast::TestXconvgemm<clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HCONVGEMM"); + if (errors > 0) { return 1; } else { return 0; } +} + +// ================================================================================================= diff --git a/test/correctness/testblas.hpp b/test/correctness/testblas.hpp index 54b2d6f8..1d1d2ca9 100644 --- a/test/correctness/testblas.hpp +++ b/test/correctness/testblas.hpp @@ -60,6 +60,7 @@ class TestBlas: public Tester<T,U> { static const std::vector<size_t> kDilationSizes; static const std::vector<size_t> kKernelSizes; static const std::vector<size_t> kBatchCounts; + static const std::vector<size_t> kNumKernels; const std::vector<size_t> kOffsets; const std::vector<U> kAlphaValues; const std::vector<U> kBetaValues; @@ -136,6 +137,7 @@ template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kBatc template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kPadSizes = { 0, 1 }; template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kDilationSizes = { 1, 2 }; template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kKernelSizes = { 1, 3 }; +template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kNumKernels = { 1, 2 }; // Test settings for the invalid tests template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kInvalidIncrements = { 0, 1 }; @@ -241,6 +243,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na auto dilation_hs = std::vector<size_t>{args.dilation_h}; auto dilation_ws = std::vector<size_t>{args.dilation_w}; auto batch_counts = std::vector<size_t>{args.batch_count}; + auto num_kernelss = std::vector<size_t>{args.num_kernels}; auto x_sizes = std::vector<size_t>{args.x_size}; auto y_sizes = std::vector<size_t>{args.y_size}; auto a_sizes = std::vector<size_t>{args.a_size}; @@ -296,6 +299,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na if (option == kArgDilationH) { dilation_hs = tester.kDilationSizes; } if (option == kArgDilationW) { dilation_ws = tester.kDilationSizes; } if (option == kArgBatchCount) { batch_counts = tester.kBatchCounts; } + if (option == kArgNumKernels) { num_kernelss = tester.kNumKernels; } if (option == kArgXOffset) { x_sizes = tester.kVecSizes; } if (option == kArgYOffset) { y_sizes = tester.kVecSizes; } @@ -350,8 +354,10 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na for (auto &dilation_h: dilation_hs) { r_args.dilation_h = dilation_h; for (auto &dilation_w: dilation_ws) { r_args.dilation_w = dilation_w; for (auto &batch_count: batch_counts) { r_args.batch_count = batch_count; - C::SetSizes(r_args, tester.queue_); - regular_test_vector.push_back(r_args); + for (auto &num_kernels: num_kernelss) { r_args.num_kernels = num_kernels; + C::SetSizes(r_args, tester.queue_); + regular_test_vector.push_back(r_args); + } } } } |