diff options
Diffstat (limited to 'test/performance')
-rw-r--r-- | test/performance/client.cpp | 4 | ||||
-rw-r--r-- | test/performance/routines/levelx/xconvgemm.cpp | 33 |
2 files changed, 36 insertions, 1 deletions
diff --git a/test/performance/client.cpp b/test/performance/client.cpp index e2d1a6c7..377e0140 100644 --- a/test/performance/client.cpp +++ b/test/performance/client.cpp @@ -106,7 +106,7 @@ Arguments<U> Client<T,U>::ParseArguments(int argc, char *argv[], const size_t le if (o == kArgAlpha) { args.alpha = GetArgument(command_line_args, help, kArgAlpha, GetScalar<U>()); } if (o == kArgBeta) { args.beta = GetArgument(command_line_args, help, kArgBeta, GetScalar<U>()); } - // Arguments for im2col + // Arguments for im2col and convgemm if (o == kArgChannels) { args.channels = GetArgument(command_line_args, help, kArgChannels, size_t{64}); } if (o == kArgHeight) { args.height = GetArgument(command_line_args, help, kArgHeight, size_t{64}); } if (o == kArgWidth) { args.width = GetArgument(command_line_args, help, kArgWidth, size_t{64}); } @@ -118,6 +118,7 @@ Arguments<U> Client<T,U>::ParseArguments(int argc, char *argv[], const size_t le if (o == kArgStrideW) { args.stride_w = GetArgument(command_line_args, help, kArgStrideW, size_t{1}); } if (o == kArgDilationH) { args.dilation_h = GetArgument(command_line_args, help, kArgDilationH, size_t{1}); } if (o == kArgDilationW) { args.dilation_w = GetArgument(command_line_args, help, kArgDilationW, size_t{1}); } + if (o == kArgNumKernels){ args.num_kernels = GetArgument(command_line_args, help, kArgNumKernels, size_t{1}); } } // These are the options common to all routines @@ -446,6 +447,7 @@ void Client<T,U>::PrintTableRow(const Arguments<U>& args, else if (o == kArgStrideW) {integers.push_back(args.stride_w); } else if (o == kArgDilationH) {integers.push_back(args.dilation_h); } else if (o == kArgDilationW) {integers.push_back(args.dilation_w); } + else if (o == kArgNumKernels){integers.push_back(args.num_kernels); } } auto strings = std::vector<std::string>{}; for (auto &o: options_) { diff --git a/test/performance/routines/levelx/xconvgemm.cpp b/test/performance/routines/levelx/xconvgemm.cpp new file mode 100644 index 00000000..e6bbe8e8 --- /dev/null +++ b/test/performance/routines/levelx/xconvgemm.cpp @@ -0,0 +1,33 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren <www.cedricnugteren.nl> +// +// ================================================================================================= + +#include "test/performance/client.hpp" +#include "test/routines/levelx/xconvgemm.hpp" + +// Main function (not within the clblast namespace) +int main(int argc, char *argv[]) { + const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv); + switch(clblast::GetPrecision(command_line_args, clblast::Precision::kSingle)) { + case clblast::Precision::kHalf: + clblast::RunClient<clblast::TestXconvgemm<clblast::half>, clblast::half, clblast::half>(argc, argv); break; + case clblast::Precision::kSingle: + clblast::RunClient<clblast::TestXconvgemm<float>, float, float>(argc, argv); break; + case clblast::Precision::kDouble: + clblast::RunClient<clblast::TestXconvgemm<double>, double, double>(argc, argv); break; + case clblast::Precision::kComplexSingle: + clblast::RunClient<clblast::TestXconvgemm<clblast::float2>, clblast::float2, clblast::float2>(argc, argv); break; + case clblast::Precision::kComplexDouble: + clblast::RunClient<clblast::TestXconvgemm<clblast::double2>, clblast::double2, clblast::double2>(argc, argv); break; + } + return 0; +} + +// ================================================================================================= |