summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-10-01 20:36:56 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2017-10-01 20:36:56 +0200
commit74fd6767b93b03fc62462f44854215c4c320babe (patch)
treef44dfee8b73093f776f24ad1db9aca013eb9dc8d
parent6b226028d5a65adc24e8b474b88b1e5d6bfe6015 (diff)
GEMM tests now test both the in-direct and the direct kernels seperately
-rw-r--r--test/correctness/misc/override_parameters.cpp2
-rw-r--r--test/correctness/routines/level3/xgemm.cpp20
-rw-r--r--test/performance/routines/level3/xgemm.cpp10
-rw-r--r--test/routines/level3/xgemm.hpp9
4 files changed, 29 insertions, 12 deletions
diff --git a/test/correctness/misc/override_parameters.cpp b/test/correctness/misc/override_parameters.cpp
index 95ece98c..05f40f57 100644
--- a/test/correctness/misc/override_parameters.cpp
+++ b/test/correctness/misc/override_parameters.cpp
@@ -28,7 +28,7 @@ size_t RunOverrideTests(int argc, char *argv[], const bool silent, const std::st
auto arguments = RetrieveCommandLineArguments(argc, argv);
auto errors = size_t{0};
auto passed = size_t{0};
- auto example_routine = TestXgemm<T>();
+ auto example_routine = TestXgemm<0, T>();
constexpr auto kSeed = 42; // fixed seed for reproducibility
// Determines the test settings
diff --git a/test/correctness/routines/level3/xgemm.cpp b/test/correctness/routines/level3/xgemm.cpp
index 5de73554..bdf57b36 100644
--- a/test/correctness/routines/level3/xgemm.cpp
+++ b/test/correctness/routines/level3/xgemm.cpp
@@ -15,11 +15,21 @@
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
auto errors = size_t{0};
- errors += clblast::RunTests<clblast::TestXgemm<float>, float, float>(argc, argv, false, "SGEMM");
- errors += clblast::RunTests<clblast::TestXgemm<double>, double, double>(argc, argv, true, "DGEMM");
- errors += clblast::RunTests<clblast::TestXgemm<clblast::float2>, clblast::float2, clblast::float2>(argc, argv, true, "CGEMM");
- errors += clblast::RunTests<clblast::TestXgemm<clblast::double2>, clblast::double2, clblast::double2>(argc, argv, true, "ZGEMM");
- errors += clblast::RunTests<clblast::TestXgemm<clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HGEMM");
+
+ // Tests GEMM based on the 'in-direct' kernel
+ errors += clblast::RunTests<clblast::TestXgemm<1, float>, float, float>(argc, argv, false, "SGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<1, double>, double, double>(argc, argv, true, "DGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<1, clblast::float2>, clblast::float2, clblast::float2>(argc, argv, true, "CGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<1, clblast::double2>, clblast::double2, clblast::double2>(argc, argv, true, "ZGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<1, clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HGEMM");
+
+ // Tests GEMM based on the 'direct' kernel
+ errors += clblast::RunTests<clblast::TestXgemm<2, float>, float, float>(argc, argv, true, "SGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<2, double>, double, double>(argc, argv, true, "DGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<2, clblast::float2>, clblast::float2, clblast::float2>(argc, argv, true, "CGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<2, clblast::double2>, clblast::double2, clblast::double2>(argc, argv, true, "ZGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<2, clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HGEMM");
+
if (errors > 0) { return 1; } else { return 0; }
}
diff --git a/test/performance/routines/level3/xgemm.cpp b/test/performance/routines/level3/xgemm.cpp
index 5b3426f5..0b67b4d3 100644
--- a/test/performance/routines/level3/xgemm.cpp
+++ b/test/performance/routines/level3/xgemm.cpp
@@ -17,15 +17,15 @@ int main(int argc, char *argv[]) {
const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
switch(clblast::GetPrecision(command_line_args, clblast::Precision::kSingle)) {
case clblast::Precision::kHalf:
- clblast::RunClient<clblast::TestXgemm<clblast::half>, clblast::half, clblast::half>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemm<0, clblast::half>, clblast::half, clblast::half>(argc, argv); break;
case clblast::Precision::kSingle:
- clblast::RunClient<clblast::TestXgemm<float>, float, float>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemm<0, float>, float, float>(argc, argv); break;
case clblast::Precision::kDouble:
- clblast::RunClient<clblast::TestXgemm<double>, double, double>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemm<0, double>, double, double>(argc, argv); break;
case clblast::Precision::kComplexSingle:
- clblast::RunClient<clblast::TestXgemm<clblast::float2>, clblast::float2, clblast::float2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemm<0, clblast::float2>, clblast::float2, clblast::float2>(argc, argv); break;
case clblast::Precision::kComplexDouble:
- clblast::RunClient<clblast::TestXgemm<clblast::double2>, clblast::double2, clblast::double2>(argc, argv); break;
+ clblast::RunClient<clblast::TestXgemm<0, clblast::double2>, clblast::double2, clblast::double2>(argc, argv); break;
}
return 0;
}
diff --git a/test/routines/level3/xgemm.hpp b/test/routines/level3/xgemm.hpp
index 7e0ead6d..1c430c1c 100644
--- a/test/routines/level3/xgemm.hpp
+++ b/test/routines/level3/xgemm.hpp
@@ -22,7 +22,7 @@ namespace clblast {
// =================================================================================================
// See comment at top of file for a description of the class
-template <typename T>
+template <int V, typename T> // 'V' is the version of the kernel (0 for default, 1 for 'in-direct', 2 for 'direct')
class TestXgemm {
public:
@@ -83,6 +83,13 @@ class TestXgemm {
// Describes how to run the CLBlast routine
static StatusCode RunRoutine(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
+ if (V != 0) {
+ const auto device = queue.GetDevice();
+ const auto switch_threshold = (V == 1) ? size_t{0} : size_t{1024 * 1024 * 1024}; // large enough for tests
+ const auto override_status = OverrideParameters(device(), "KernelSelection", PrecisionValue<T>(),
+ {{"XGEMM_MIN_INDIRECT_SIZE", switch_threshold}});
+ if (override_status != StatusCode::kSuccess) { return override_status; }
+ }
auto queue_plain = queue();
auto event = cl_event{};
auto status = Gemm(args.layout, args.a_transpose, args.b_transpose,