summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-02-14 21:09:00 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2017-02-14 21:09:00 +0100
commitbdc57221bd0279bcdb4f024df54f08a2fe1bb8d4 (patch)
tree45a8949bd2281c7ea8ca99b4b5489f6ac1bd7648 /test
parentcdb3bb7166bc75842ff95e14915bff881297fc62 (diff)
Added simple tests for the OverrideParameters function
Diffstat (limited to 'test')
-rw-r--r--test/correctness/misc/override_parameters.cpp135
1 files changed, 135 insertions, 0 deletions
diff --git a/test/correctness/misc/override_parameters.cpp b/test/correctness/misc/override_parameters.cpp
new file mode 100644
index 00000000..54d18cfa
--- /dev/null
+++ b/test/correctness/misc/override_parameters.cpp
@@ -0,0 +1,135 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file contains the tests for the OverrideParameters function
+//
+// =================================================================================================
+
+#include "utilities/utilities.hpp"
+#include "test/routines/level3/xgemm.hpp"
+
+#include <unordered_map>
+
+namespace clblast {
+// =================================================================================================
+
+template <typename T>
+size_t RunOverrideTests(int argc, char *argv[], const bool silent, const std::string &routine_name) {
+ auto arguments = RetrieveCommandLineArguments(argc, argv);
+ auto errors = size_t{0};
+ auto passed = size_t{0};
+ auto example_routine = TestXgemm<T>();
+ constexpr auto kSeed = 42; // fixed seed for reproducibility
+
+ // Determines the test settings
+ const auto kernel_name = std::string{"Xgemm"};
+ const auto precision = PrecisionValue<T>();
+ const auto valid_settings = std::vector<std::unordered_map<std::string,size_t>>{
+ { {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} },
+ { {"KWG",32}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",32}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",32}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} },
+ };
+ const auto invalid_settings = std::vector<std::unordered_map<std::string,size_t>>{
+ { {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0} },
+ };
+
+ // Retrieves the arguments
+ auto help = std::string{"Options given/available:\n"};
+ const auto platform_id = GetArgument(arguments, help, kArgPlatform, ConvertArgument(std::getenv("CLBLAST_PLATFORM"), size_t{0}));
+ const auto device_id = GetArgument(arguments, help, kArgDevice, ConvertArgument(std::getenv("CLBLAST_DEVICE"), size_t{0}));
+ auto args = Arguments<T>{};
+ args.m = GetArgument(arguments, help, kArgM, size_t{512});
+ args.n = GetArgument(arguments, help, kArgN, size_t{512});
+ args.k = GetArgument(arguments, help, kArgK, size_t{512});
+ args.a_ld = GetArgument(arguments, help, kArgALeadDim, args.k);
+ args.b_ld = GetArgument(arguments, help, kArgBLeadDim, args.n);
+ args.c_ld = GetArgument(arguments, help, kArgCLeadDim, args.n);
+ args.a_offset = GetArgument(arguments, help, kArgAOffset, size_t{0});
+ args.b_offset = GetArgument(arguments, help, kArgBOffset, size_t{0});
+ args.c_offset = GetArgument(arguments, help, kArgCOffset, size_t{0});
+ args.layout = GetArgument(arguments, help, kArgLayout, Layout::kRowMajor);
+ args.a_transpose = GetArgument(arguments, help, kArgATransp, Transpose::kNo);
+ args.b_transpose = GetArgument(arguments, help, kArgBTransp, Transpose::kNo);
+ args.alpha = GetArgument(arguments, help, kArgAlpha, GetScalar<T>());
+ args.beta = GetArgument(arguments, help, kArgBeta, GetScalar<T>());
+
+ // Prints the help message (command-line arguments)
+ if (!silent) { fprintf(stdout, "\n* %s\n", help.c_str()); }
+
+ // Initializes OpenCL
+ const auto platform = Platform(platform_id);
+ const auto device = Device(platform, device_id);
+ const auto context = Context(device);
+ auto queue = Queue(context, device);
+
+ // Populate host matrices with some example data
+ auto host_a = std::vector<T>(args.m * args.k);
+ auto host_b = std::vector<T>(args.n * args.k);
+ auto host_c = std::vector<T>(args.m * args.n);
+ PopulateVector(host_a, kSeed);
+ PopulateVector(host_b, kSeed);
+ PopulateVector(host_c, kSeed);
+
+ // Copy the matrices to the device
+ auto device_a = Buffer<T>(context, host_a.size());
+ auto device_b = Buffer<T>(context, host_b.size());
+ auto device_c = Buffer<T>(context, host_c.size());
+ device_a.Write(queue, host_a.size(), host_a);
+ device_b.Write(queue, host_b.size(), host_b);
+ device_c.Write(queue, host_c.size(), host_c);
+ auto dummy = Buffer<T>(context, 1);
+ auto buffers = Buffers<T>{dummy, dummy, device_a, device_b, device_c, dummy, dummy};
+
+ // Loops over the valid combinations: run before and run afterwards
+ fprintf(stdout, "* Testing OverrideParameters for '%s'\n", routine_name.c_str());
+ for (const auto &override_setting : valid_settings) {
+ const auto status_before = example_routine.RunRoutine(args, buffers, queue);
+ if (status_before != StatusCode::kSuccess) { errors++; continue; }
+
+ // Overrides the parameters
+ const auto status = OverrideParameters(device(), kernel_name, precision, override_setting);
+ if (status != StatusCode::kSuccess) { errors++; continue; } // error shouldn't occur
+
+ const auto status_after = example_routine.RunRoutine(args, buffers, queue);
+ if (status_after != StatusCode::kSuccess) { errors++; continue; }
+ passed++;
+ }
+
+ // Loops over the invalid combinations: run before and run afterwards
+ for (const auto &override_setting : invalid_settings) {
+ const auto status_before = example_routine.RunRoutine(args, buffers, queue);
+ if (status_before != StatusCode::kSuccess) { errors++; continue; }
+
+ // Overrides the parameters
+ const auto status = OverrideParameters(device(), kernel_name, precision, override_setting);
+ if (status == StatusCode::kSuccess) { errors++; continue; } // error should occur
+
+ const auto status_after = example_routine.RunRoutine(args, buffers, queue);
+ if (status_after != StatusCode::kSuccess) { errors++; continue; }
+ passed++;
+ }
+
+ // Prints and returns the statistics
+ fprintf(stdout, " %zu test(s) passed\n", passed);
+ fprintf(stdout, " %zu test(s) failed\n", errors);
+ fprintf(stdout, "\n");
+ return errors;
+}
+
+// =================================================================================================
+} // namespace clblast
+
+// Main function (not within the clblast namespace)
+int main(int argc, char *argv[]) {
+ auto errors = size_t{0};
+ errors += clblast::RunOverrideTests<float>(argc, argv, false, "SGEMM");
+ errors += clblast::RunOverrideTests<double>(argc, argv, true, "DGEMM");
+ if (errors > 0) { return 1; } else { return 0; }
+}
+
+// =================================================================================================