diff options
Diffstat (limited to 'src/utilities/utilities.cpp')
-rw-r--r-- | src/utilities/utilities.cpp | 69 |
1 files changed, 47 insertions, 22 deletions
diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp index 24456252..5e445bb9 100644 --- a/src/utilities/utilities.cpp +++ b/src/utilities/utilities.cpp @@ -158,6 +158,27 @@ std::string ToString(StatusCode value) { // ================================================================================================= +// Retrieves the command-line arguments in a C++ fashion. Also adds command-line arguments from +// pre-defined environmental variables +std::vector<std::string> RetrieveCommandLineArguments(int argc, char *argv[]) { + + // Regular command-line arguments + auto command_line_args = std::vector<std::string>(); + for (auto i=0; i<argc; ++i) { + command_line_args.push_back(std::string{argv[i]}); + } + + // Extra CLBlast arguments + const auto extra_args = ConvertArgument(std::getenv("CLBLAST_ARGUMENTS"), std::string{""}); + std::stringstream extra_args_stream; + extra_args_stream.str(extra_args); + std::string extra_arg; + while (std::getline(extra_args_stream, extra_arg, ' ')) { + command_line_args.push_back(extra_arg); + } + return command_line_args; +} + // Helper for the below function to convert the argument to the value type. Adds specialization for // complex data-types. Note that complex arguments are accepted as regular values and are copied to // both the real and imaginary parts. @@ -167,6 +188,9 @@ T ConvertArgument(const char* value) { } template size_t ConvertArgument(const char* value); +template <> std::string ConvertArgument(const char* value) { + return std::string{value}; +} template <> half ConvertArgument(const char* value) { return FloatToHalf(static_cast<float>(std::stod(value))); } @@ -193,21 +217,22 @@ T ConvertArgument(const char* value, T default_value) { return default_value; } template size_t ConvertArgument(const char* value, size_t default_value); +template std::string ConvertArgument(const char* value, std::string default_value); // This function matches patterns in the form of "-option value" or "--option value". It returns a // default value in case the option is not found in the argument string. template <typename T> -T GetArgument(const int argc, char **argv, std::string &help, +T GetArgument(const std::vector<std::string> &arguments, std::string &help, const std::string &option, const T default_value) { // Parses the argument. Note that this supports both the given option (e.g. -device) and one with // an extra dash in front (e.g. --device). auto return_value = static_cast<T>(default_value); - for (int c=0; c<argc; ++c) { - auto item = std::string{argv[c]}; + for (auto c=size_t{0}; c<arguments.size(); ++c) { + auto item = arguments[c]; if (item.compare("-"+option) == 0 || item.compare("--"+option) == 0) { ++c; - return_value = ConvertArgument<T>(argv[c]); + return_value = ConvertArgument<T>(arguments[c].c_str()); break; } } @@ -219,39 +244,39 @@ T GetArgument(const int argc, char **argv, std::string &help, } // Compiles the above function -template int GetArgument<int>(const int, char **, std::string&, const std::string&, const int); -template size_t GetArgument<size_t>(const int, char **, std::string&, const std::string&, const size_t); -template half GetArgument<half>(const int, char **, std::string&, const std::string&, const half); -template float GetArgument<float>(const int, char **, std::string&, const std::string&, const float); -template double GetArgument<double>(const int, char **, std::string&, const std::string&, const double); -template float2 GetArgument<float2>(const int, char **, std::string&, const std::string&, const float2); -template double2 GetArgument<double2>(const int, char **, std::string&, const std::string&, const double2); -template Layout GetArgument<Layout>(const int, char **, std::string&, const std::string&, const Layout); -template Transpose GetArgument<Transpose>(const int, char **, std::string&, const std::string&, const Transpose); -template Side GetArgument<Side>(const int, char **, std::string&, const std::string&, const Side); -template Triangle GetArgument<Triangle>(const int, char **, std::string&, const std::string&, const Triangle); -template Diagonal GetArgument<Diagonal>(const int, char **, std::string&, const std::string&, const Diagonal); -template Precision GetArgument<Precision>(const int, char **, std::string&, const std::string&, const Precision); +template int GetArgument<int>(const std::vector<std::string>&, std::string&, const std::string&, const int); +template size_t GetArgument<size_t>(const std::vector<std::string>&, std::string&, const std::string&, const size_t); +template half GetArgument<half>(const std::vector<std::string>&, std::string&, const std::string&, const half); +template float GetArgument<float>(const std::vector<std::string>&, std::string&, const std::string&, const float); +template double GetArgument<double>(const std::vector<std::string>&, std::string&, const std::string&, const double); +template float2 GetArgument<float2>(const std::vector<std::string>&, std::string&, const std::string&, const float2); +template double2 GetArgument<double2>(const std::vector<std::string>&, std::string&, const std::string&, const double2); +template Layout GetArgument<Layout>(const std::vector<std::string>&, std::string&, const std::string&, const Layout); +template Transpose GetArgument<Transpose>(const std::vector<std::string>&, std::string&, const std::string&, const Transpose); +template Side GetArgument<Side>(const std::vector<std::string>&, std::string&, const std::string&, const Side); +template Triangle GetArgument<Triangle>(const std::vector<std::string>&, std::string&, const std::string&, const Triangle); +template Diagonal GetArgument<Diagonal>(const std::vector<std::string>&, std::string&, const std::string&, const Diagonal); +template Precision GetArgument<Precision>(const std::vector<std::string>&, std::string&, const std::string&, const Precision); // ================================================================================================= // Returns only the precision argument -Precision GetPrecision(const int argc, char *argv[], const Precision default_precision) { +Precision GetPrecision(const std::vector<std::string> &arguments, const Precision default_precision) { auto dummy = std::string{}; - return GetArgument(argc, argv, dummy, kArgPrecision, default_precision); + return GetArgument(arguments, dummy, kArgPrecision, default_precision); } // ================================================================================================= // Checks whether an argument is given. Returns true or false. -bool CheckArgument(const int argc, char *argv[], std::string &help, +bool CheckArgument(const std::vector<std::string> &arguments, std::string &help, const std::string &option) { // Parses the argument. Note that this supports both the given option (e.g. -device) and one with // an extra dash in front (e.g. --device). auto return_value = false; - for (int c=0; c<argc; ++c) { - auto item = std::string{argv[c]}; + for (auto c=size_t{0}; c<arguments.size(); ++c) { + auto item = arguments[c]; if (item.compare("-"+option) == 0 || item.compare("--"+option) == 0) { ++c; return_value = true; |