diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-11-21 22:05:08 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-11-21 22:05:08 +0100 |
commit | 8c9ecd97366980200a58a4b8cd77bd7f8b859abc (patch) | |
tree | 6f94e48ac0736aed2c9d594b5e4e10199fa97e52 /test/test_utilities.cpp | |
parent | 606990af6f7297528dcc44f67ce777e1ba56d2d0 (diff) |
Implemented first version of reading JSON files from disk in the client to override parameters
Diffstat (limited to 'test/test_utilities.cpp')
-rw-r--r-- | test/test_utilities.cpp | 75 |
1 files changed, 73 insertions, 2 deletions
diff --git a/test/test_utilities.cpp b/test/test_utilities.cpp index 84f8894f..b7aef0a0 100644 --- a/test/test_utilities.cpp +++ b/test/test_utilities.cpp @@ -11,10 +11,11 @@ // // ================================================================================================= -#include "test/test_utilities.hpp" - #include <string> #include <vector> +#include <cctype> + +#include "test/test_utilities.hpp" namespace clblast { // ================================================================================================= @@ -113,4 +114,74 @@ void FloatToHalfBuffer(std::vector<half>& result, const std::vector<float>& sour #endif // ================================================================================================= + +void OverrideParametersFromJSONFiles(const cl_device_id device, const Precision precision) { + const auto json_file_name = std::getenv("CLBLAST_JSON_FILE_OVERRIDE"); + if (json_file_name == nullptr) { return; } + const auto json_file_name_string = std::string{json_file_name}; + OverrideParametersFromJSONFile(json_file_name_string, device, precision); +} + +void OverrideParametersFromJSONFile(const std::string& file_name, + const cl_device_id device, const Precision precision) { + + std::ifstream json_file(file_name); + if (!json_file) { return; } + + fprintf(stdout, "* Reading override-parameters from '%s'\n", file_name.c_str()); + std::string line; + auto kernel_name = std::string{}; + while (std::getline(json_file, line)) { + const auto line_split = split(line, ':'); + if (line_split.size() != 2) { continue; } + + // Retrieves the kernel name + if (line_split[0] == " \"kernel_family\"") { + const auto value_split = split(line_split[1], '\"'); + if (value_split.size() != 3) { break; } + kernel_name = value_split[1]; + kernel_name[0] = toupper(kernel_name[0]); // because of a tuner - database naming mismatch + } + + // Retrieves the best-parameters and sets the override + if (line_split[0] == " \"best_parameters\"" && kernel_name != "") { + const auto value_split = split(line_split[1], '\"'); + if (value_split.size() != 3) { break; } + const auto config_split = split(value_split[1], ' '); + if (config_split.size() == 0) { break; } + + // Creates the list of parameters + fprintf(stdout, "* Found parameters for kernel '%s': { ", kernel_name.c_str()); + std::unordered_map<std::string,size_t> parameters; + for (const auto config : config_split) { + const auto params_split = split(config, '='); + if (params_split.size() != 2) { break; } + const auto parameter_name = params_split[0]; + if (parameter_name != "PRECISION") { + const auto parameter_value = static_cast<size_t>(std::stoi(params_split[1].c_str())); + printf("%s=%zu ", parameter_name.c_str(), parameter_value); + parameters[parameter_name] = parameter_value; + } + } + fprintf(stdout, "}\n"); + + // Applies the parameter override + const auto status = OverrideParameters(device, kernel_name, precision, parameters); + if (status != StatusCode::kSuccess) { break; } + + // Ends this function (success) + fprintf(stdout, "* Applying parameter override successfully\n"); + fprintf(stdout, "\n"); + json_file.close(); + return; + } + } + + // Ends this function (failure) + fprintf(stdout, "* Failed to extract parameters from the file, continuing regularly\n"); + fprintf(stdout, "\n"); + json_file.close(); +} + +// ================================================================================================= } // namespace clblast |