summaryrefslogtreecommitdiff
path: root/test/test_utilities.cpp
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-11-21 22:05:08 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2017-11-21 22:05:08 +0100
commit8c9ecd97366980200a58a4b8cd77bd7f8b859abc (patch)
tree6f94e48ac0736aed2c9d594b5e4e10199fa97e52 /test/test_utilities.cpp
parent606990af6f7297528dcc44f67ce777e1ba56d2d0 (diff)
Implemented first version of reading JSON files from disk in the client to override parameters
Diffstat (limited to 'test/test_utilities.cpp')
-rw-r--r--test/test_utilities.cpp75
1 files changed, 73 insertions, 2 deletions
diff --git a/test/test_utilities.cpp b/test/test_utilities.cpp
index 84f8894f..b7aef0a0 100644
--- a/test/test_utilities.cpp
+++ b/test/test_utilities.cpp
@@ -11,10 +11,11 @@
//
// =================================================================================================
-#include "test/test_utilities.hpp"
-
#include <string>
#include <vector>
+#include <cctype>
+
+#include "test/test_utilities.hpp"
namespace clblast {
// =================================================================================================
@@ -113,4 +114,74 @@ void FloatToHalfBuffer(std::vector<half>& result, const std::vector<float>& sour
#endif
// =================================================================================================
+
+void OverrideParametersFromJSONFiles(const cl_device_id device, const Precision precision) {
+ const auto json_file_name = std::getenv("CLBLAST_JSON_FILE_OVERRIDE");
+ if (json_file_name == nullptr) { return; }
+ const auto json_file_name_string = std::string{json_file_name};
+ OverrideParametersFromJSONFile(json_file_name_string, device, precision);
+}
+
+void OverrideParametersFromJSONFile(const std::string& file_name,
+ const cl_device_id device, const Precision precision) {
+
+ std::ifstream json_file(file_name);
+ if (!json_file) { return; }
+
+ fprintf(stdout, "* Reading override-parameters from '%s'\n", file_name.c_str());
+ std::string line;
+ auto kernel_name = std::string{};
+ while (std::getline(json_file, line)) {
+ const auto line_split = split(line, ':');
+ if (line_split.size() != 2) { continue; }
+
+ // Retrieves the kernel name
+ if (line_split[0] == " \"kernel_family\"") {
+ const auto value_split = split(line_split[1], '\"');
+ if (value_split.size() != 3) { break; }
+ kernel_name = value_split[1];
+ kernel_name[0] = toupper(kernel_name[0]); // because of a tuner - database naming mismatch
+ }
+
+ // Retrieves the best-parameters and sets the override
+ if (line_split[0] == " \"best_parameters\"" && kernel_name != "") {
+ const auto value_split = split(line_split[1], '\"');
+ if (value_split.size() != 3) { break; }
+ const auto config_split = split(value_split[1], ' ');
+ if (config_split.size() == 0) { break; }
+
+ // Creates the list of parameters
+ fprintf(stdout, "* Found parameters for kernel '%s': { ", kernel_name.c_str());
+ std::unordered_map<std::string,size_t> parameters;
+ for (const auto config : config_split) {
+ const auto params_split = split(config, '=');
+ if (params_split.size() != 2) { break; }
+ const auto parameter_name = params_split[0];
+ if (parameter_name != "PRECISION") {
+ const auto parameter_value = static_cast<size_t>(std::stoi(params_split[1].c_str()));
+ printf("%s=%zu ", parameter_name.c_str(), parameter_value);
+ parameters[parameter_name] = parameter_value;
+ }
+ }
+ fprintf(stdout, "}\n");
+
+ // Applies the parameter override
+ const auto status = OverrideParameters(device, kernel_name, precision, parameters);
+ if (status != StatusCode::kSuccess) { break; }
+
+ // Ends this function (success)
+ fprintf(stdout, "* Applying parameter override successfully\n");
+ fprintf(stdout, "\n");
+ json_file.close();
+ return;
+ }
+ }
+
+ // Ends this function (failure)
+ fprintf(stdout, "* Failed to extract parameters from the file, continuing regularly\n");
+ fprintf(stdout, "\n");
+ json_file.close();
+}
+
+// =================================================================================================
} // namespace clblast