diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-01-11 23:09:48 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-01-11 23:09:48 +0100 |
commit | 6d52eb2956c294a46b43627c468813cf64f60b99 (patch) | |
tree | caa7d889091d004d066772083a157a38f9e6d967 /test | |
parent | 9b084d04093fdbfb22ee4790c6b3db5c55cd2719 (diff) | |
parent | 90e8e55acb4f4c059317fb7e812389a9704b2cb3 (diff) |
Merge pull request #240 from CNugteren/retrieve_tuning_parameters
Retrieve tuning parameters
Diffstat (limited to 'test')
-rw-r--r-- | test/correctness/misc/override_parameters.cpp | 3 | ||||
-rw-r--r-- | test/correctness/misc/retrieve_parameters.cpp | 84 |
2 files changed, 86 insertions, 1 deletions
diff --git a/test/correctness/misc/override_parameters.cpp b/test/correctness/misc/override_parameters.cpp index 05f40f57..1edfb2ba 100644 --- a/test/correctness/misc/override_parameters.cpp +++ b/test/correctness/misc/override_parameters.cpp @@ -86,11 +86,12 @@ size_t RunOverrideTests(int argc, char *argv[], const bool silent, const std::st auto device_a = Buffer<T>(context, host_a.size()); auto device_b = Buffer<T>(context, host_b.size()); auto device_c = Buffer<T>(context, host_c.size()); + auto device_temp = Buffer<T>(context, args.m * args.n * args.k); // just to be safe device_a.Write(queue, host_a.size(), host_a); device_b.Write(queue, host_b.size(), host_b); device_c.Write(queue, host_c.size(), host_c); auto dummy = Buffer<T>(context, 1); - auto buffers = Buffers<T>{dummy, dummy, device_a, device_b, device_c, dummy, dummy}; + auto buffers = Buffers<T>{dummy, dummy, device_a, device_b, device_c, device_temp, dummy}; // Loops over the valid combinations: run before and run afterwards fprintf(stdout, "* Testing OverrideParameters for '%s'\n", routine_name.c_str()); diff --git a/test/correctness/misc/retrieve_parameters.cpp b/test/correctness/misc/retrieve_parameters.cpp new file mode 100644 index 00000000..568dab0d --- /dev/null +++ b/test/correctness/misc/retrieve_parameters.cpp @@ -0,0 +1,84 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren <www.cedricnugteren.nl> +// +// This file contains the tests for the RetrieveParameters function +// +// ================================================================================================= + +#include <string> +#include <vector> +#include <unordered_map> +#include <iostream> + +#include "utilities/utilities.hpp" + +namespace clblast { +// ================================================================================================= + +template <typename T> +size_t RunRetrieveParametersTests(int argc, char *argv[], const bool silent, const std::string &routine_name) { + auto arguments = RetrieveCommandLineArguments(argc, argv); + auto errors = size_t{0}; + auto passed = size_t{0}; + + // Retrieves the arguments + auto help = std::string{"Options given/available:\n"}; + const auto platform_id = GetArgument(arguments, help, kArgPlatform, ConvertArgument(std::getenv("CLBLAST_PLATFORM"), size_t{0})); + const auto device_id = GetArgument(arguments, help, kArgDevice, ConvertArgument(std::getenv("CLBLAST_DEVICE"), size_t{0})); + auto args = Arguments<T>{}; + + // Determines the test settings + const auto kernel_name = std::string{"Xgemm"}; + const auto expected_parameters = std::vector<std::string>{ + "KWG", "KWI", "MDIMA", "MDIMC", "MWG", "NDIMB", "NDIMC", "NWG", "SA", "SB", "STRM", "STRN", "VWM", "VWN" + }; + const auto expected_max_value = size_t{16384}; + + // Prints the help message (command-line arguments) + if (!silent) { fprintf(stdout, "\n* %s\n", help.c_str()); } + + // Initializes OpenCL + const auto platform = Platform(platform_id); + const auto device = Device(platform, device_id); + + // Retrieves the parameters + fprintf(stdout, "* Testing RetrieveParameters for '%s'\n", routine_name.c_str()); + auto parameters = std::unordered_map<std::string,size_t>(); + const auto status = RetrieveParameters(device(), kernel_name, PrecisionValue<T>(), parameters); + if (status != StatusCode::kSuccess) { errors++; } + + // Verifies the parameters + for (const auto &expected_parameter : expected_parameters) { + if (parameters.find(expected_parameter) != parameters.end()) { + const auto value = parameters[expected_parameter]; + if (value < expected_max_value) { passed++; } else { errors++; } + //std::cout << expected_parameter << " = " << value << std::endl; + } + else { errors++; } + } + + // Prints and returns the statistics + std::cout << " " << passed << " test(s) passed" << std::endl; + std::cout << " " << errors << " test(s) failed" << std::endl; + std::cout << std::endl; + return errors; +} + +// ================================================================================================= +} // namespace clblast + +// Main function (not within the clblast namespace) +int main(int argc, char *argv[]) { + auto errors = size_t{0}; + errors += clblast::RunRetrieveParametersTests<float>(argc, argv, false, "SGEMM"); + errors += clblast::RunRetrieveParametersTests<clblast::float2>(argc, argv, true, "CGEMM"); + if (errors > 0) { return 1; } else { return 0; } +} + +// ================================================================================================= |