// ================================================================================================= // This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This // project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- // width of 100 characters per line. // // Author(s): // Cedric Nugteren // // This file implements the common (test) utility functions. // // ================================================================================================= #include "internal/utilities.h" #include #include #include #include #include namespace clblast { // ================================================================================================= // Implements the string conversion using std::to_string if possible template std::string ToString(T value) { return std::to_string(value); } template std::string ToString(int value); template std::string ToString(size_t value); template std::string ToString(float value); template std::string ToString(double value); // If not possible directly: special cases for complex data-types template <> std::string ToString(float2 value) { std::ostringstream real, imag; real << std::setprecision(2) << value.real(); imag << std::setprecision(2) << value.imag(); return real.str()+"+"+imag.str()+"i"; } template <> std::string ToString(double2 value) { std::ostringstream real, imag; real << std::setprecision(2) << value.real(); imag << std::setprecision(2) << value.imag(); return real.str()+"+"+imag.str()+"i"; } // If not possible directly: special cases for CLBlast data-types template <> std::string ToString(Layout value) { switch(value) { case Layout::kRowMajor: return ToString(static_cast(value))+" (row-major)"; case Layout::kColMajor: return ToString(static_cast(value))+" (col-major)"; } } template <> std::string ToString(Transpose value) { switch(value) { case Transpose::kNo: return ToString(static_cast(value))+" (regular)"; case Transpose::kYes: return ToString(static_cast(value))+" (transposed)"; case Transpose::kConjugate: return ToString(static_cast(value))+" (conjugate)"; } } template <> std::string ToString(Side value) { switch(value) { case Side::kLeft: return ToString(static_cast(value))+" (left)"; case Side::kRight: return ToString(static_cast(value))+" (right)"; } } template <> std::string ToString(Triangle value) { switch(value) { case Triangle::kUpper: return ToString(static_cast(value))+" (upper)"; case Triangle::kLower: return ToString(static_cast(value))+" (lower)"; } } template <> std::string ToString(Diagonal value) { switch(value) { case Diagonal::kUnit: return ToString(static_cast(value))+" (unit)"; case Diagonal::kNonUnit: return ToString(static_cast(value))+" (non-unit)"; } } template <> std::string ToString(Precision value) { switch(value) { case Precision::kHalf: return ToString(static_cast(value))+" (half)"; case Precision::kSingle: return ToString(static_cast(value))+" (single)"; case Precision::kDouble: return ToString(static_cast(value))+" (double)"; case Precision::kComplexSingle: return ToString(static_cast(value))+" (complex-single)"; case Precision::kComplexDouble: return ToString(static_cast(value))+" (complex-double)"; } } // ================================================================================================= // Helper for the below function to convert the argument to the value type. Adds specialization for // complex data-types. Note that complex arguments are accepted as regular values and are copied to // both the real and imaginary parts. template T ConvertArgument(const char* value) { return static_cast(std::stoi(value)); } template <> float ConvertArgument(const char* value) { return static_cast(std::stod(value)); } template <> double ConvertArgument(const char* value) { return static_cast(std::stod(value)); } template <> float2 ConvertArgument(const char* value) { auto val = static_cast(std::stod(value)); return float2{val, val}; } template <> double2 ConvertArgument(const char* value) { auto val = static_cast(std::stod(value)); return double2{val, val}; } // This function matches patterns in the form of "-option value" or "--option value". It returns a // default value in case the option is not found in the argument string. template T GetArgument(const int argc, char *argv[], std::string &help, const std::string &option, const T default_value) { // Parses the argument. Note that this supports both the given option (e.g. -device) and one with // an extra dash in front (e.g. --device). auto return_value = static_cast(default_value); for (int c=0; c(argv[c]); break; } } // Updates the help message and returns help += " -"+option+" "+ToString(return_value)+" "; help += (return_value == default_value) ? "[=default]\n" : "\n"; return return_value; } // Compiles the above function template int GetArgument(const int, char **, std::string&, const std::string&, const int); template size_t GetArgument(const int, char **, std::string&, const std::string&, const size_t); template float GetArgument(const int, char **, std::string&, const std::string&, const float); template double GetArgument(const int, char **, std::string&, const std::string&, const double); template float2 GetArgument(const int, char **, std::string&, const std::string&, const float2); template double2 GetArgument(const int, char **, std::string&, const std::string&, const double2); template Layout GetArgument(const int, char **, std::string&, const std::string&, const Layout); template Transpose GetArgument(const int, char **, std::string&, const std::string&, const Transpose); template Side GetArgument(const int, char **, std::string&, const std::string&, const Side); template Triangle GetArgument(const int, char **, std::string&, const std::string&, const Triangle); template Diagonal GetArgument(const int, char **, std::string&, const std::string&, const Diagonal); template Precision GetArgument(const int, char **, std::string&, const std::string&, const Precision); // ================================================================================================= // Returns only the precision argument Precision GetPrecision(const int argc, char *argv[], const Precision default_precision) { auto dummy = std::string{}; return GetArgument(argc, argv, dummy, kArgPrecision, default_precision); } // ================================================================================================= // Checks whether an argument is given. Returns true or false. bool CheckArgument(const int argc, char *argv[], std::string &help, const std::string &option) { // Parses the argument. Note that this supports both the given option (e.g. -device) and one with // an extra dash in front (e.g. --device). auto return_value = false; for (int c=0; c(std::chrono::system_clock::now().time_since_epoch().count()); } // Create a random number generator and populates a vector with samples from a random distribution template void PopulateVector(std::vector &vector) { auto lower_limit = static_cast(kTestDataLowerLimit); auto upper_limit = static_cast(kTestDataUpperLimit); std::mt19937 mt(GetRandomSeed()); std::uniform_real_distribution dist(lower_limit, upper_limit); for (auto &element: vector) { element = dist(mt); } } template void PopulateVector(std::vector&); template void PopulateVector(std::vector&); // Specialized versions of the above for complex data-types template <> void PopulateVector(std::vector &vector) { auto lower_limit = static_cast(kTestDataLowerLimit); auto upper_limit = static_cast(kTestDataUpperLimit); std::mt19937 mt(GetRandomSeed()); std::uniform_real_distribution dist(lower_limit, upper_limit); for (auto &element: vector) { element.real(dist(mt)); element.imag(dist(mt)); } } template <> void PopulateVector(std::vector &vector) { auto lower_limit = static_cast(kTestDataLowerLimit); auto upper_limit = static_cast(kTestDataUpperLimit); std::mt19937 mt(GetRandomSeed()); std::uniform_real_distribution dist(lower_limit, upper_limit); for (auto &element: vector) { element.real(dist(mt)); element.imag(dist(mt)); } } // ================================================================================================= // Returns a scalar with a default value template T GetScalar() { return static_cast(2.0); } template float GetScalar(); template double GetScalar(); // Specialized versions of the above for complex data-types template <> float2 GetScalar() { return {2.0f, 0.5f}; } template <> double2 GetScalar() { return {2.0, 0.5}; } // ================================================================================================= // Rounding functions performing ceiling and division operations size_t CeilDiv(const size_t x, const size_t y) { return 1 + ((x - 1) / y); } size_t Ceil(const size_t x, const size_t y) { return CeilDiv(x,y)*y; } // Helper function to determine whether or not 'a' is a multiple of 'b' bool IsMultiple(const size_t a, const size_t b) { return ((a/b)*b == a) ? true : false; }; // ================================================================================================= // Convert the precision enum (as integer) into bytes size_t GetBytes(const Precision precision) { switch(precision) { case Precision::kHalf: return 2; case Precision::kSingle: return 4; case Precision::kDouble: return 8; case Precision::kComplexSingle: return 8; case Precision::kComplexDouble: return 16; } } // ================================================================================================= // Returns false is this precision is not supported by the device template <> bool PrecisionSupported(const Device &) { return true; } template <> bool PrecisionSupported(const Device &) { return true; } template <> bool PrecisionSupported(const Device &device) { auto extensions = device.Capabilities(); return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true; } template <> bool PrecisionSupported(const Device &device) { auto extensions = device.Capabilities(); return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true; } // ================================================================================================= } // namespace clblast