diff options
Diffstat (limited to 'test/correctness')
-rw-r--r-- | test/correctness/misc/override_parameters.cpp | 3 | ||||
-rw-r--r-- | test/correctness/routines/level3/xgemm.cpp | 15 | ||||
-rw-r--r-- | test/correctness/testblas.cpp | 44 | ||||
-rw-r--r-- | test/correctness/tester.hpp | 2 |
4 files changed, 28 insertions, 36 deletions
diff --git a/test/correctness/misc/override_parameters.cpp b/test/correctness/misc/override_parameters.cpp index 535d9286..05f40f57 100644 --- a/test/correctness/misc/override_parameters.cpp +++ b/test/correctness/misc/override_parameters.cpp @@ -28,7 +28,7 @@ size_t RunOverrideTests(int argc, char *argv[], const bool silent, const std::st auto arguments = RetrieveCommandLineArguments(argc, argv); auto errors = size_t{0}; auto passed = size_t{0}; - auto example_routine = TestXgemm<T>(); + auto example_routine = TestXgemm<0, T>(); constexpr auto kSeed = 42; // fixed seed for reproducibility // Determines the test settings @@ -37,6 +37,7 @@ size_t RunOverrideTests(int argc, char *argv[], const bool silent, const std::st const auto valid_settings = std::vector<std::unordered_map<std::string,size_t>>{ { {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} }, { {"KWG",32}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",32}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",32}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} }, + { {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} }, }; const auto invalid_settings = std::vector<std::unordered_map<std::string,size_t>>{ { {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0} }, diff --git a/test/correctness/routines/level3/xgemm.cpp b/test/correctness/routines/level3/xgemm.cpp index 5de73554..351e538b 100644 --- a/test/correctness/routines/level3/xgemm.cpp +++ b/test/correctness/routines/level3/xgemm.cpp @@ -15,11 +15,16 @@ // Main function (not within the clblast namespace) int main(int argc, char *argv[]) { auto errors = size_t{0}; - errors += clblast::RunTests<clblast::TestXgemm<float>, float, float>(argc, argv, false, "SGEMM"); - errors += clblast::RunTests<clblast::TestXgemm<double>, double, double>(argc, argv, true, "DGEMM"); - errors += clblast::RunTests<clblast::TestXgemm<clblast::float2>, clblast::float2, clblast::float2>(argc, argv, true, "CGEMM"); - errors += clblast::RunTests<clblast::TestXgemm<clblast::double2>, clblast::double2, clblast::double2>(argc, argv, true, "ZGEMM"); - errors += clblast::RunTests<clblast::TestXgemm<clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HGEMM"); + errors += clblast::RunTests<clblast::TestXgemm<1, float>, float, float>(argc, argv, false, "SGEMM"); + errors += clblast::RunTests<clblast::TestXgemm<1, double>, double, double>(argc, argv, true, "DGEMM"); + errors += clblast::RunTests<clblast::TestXgemm<1, clblast::float2>, clblast::float2, clblast::float2>(argc, argv, true, "CGEMM"); + errors += clblast::RunTests<clblast::TestXgemm<1, clblast::double2>, clblast::double2, clblast::double2>(argc, argv, true, "ZGEMM"); + errors += clblast::RunTests<clblast::TestXgemm<1, clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HGEMM"); + errors += clblast::RunTests<clblast::TestXgemm<2, float>, float, float>(argc, argv, true, "SGEMM"); + errors += clblast::RunTests<clblast::TestXgemm<2, double>, double, double>(argc, argv, true, "DGEMM"); + errors += clblast::RunTests<clblast::TestXgemm<2, clblast::float2>, clblast::float2, clblast::float2>(argc, argv, true, "CGEMM"); + errors += clblast::RunTests<clblast::TestXgemm<2, clblast::double2>, clblast::double2, clblast::double2>(argc, argv, true, "ZGEMM"); + errors += clblast::RunTests<clblast::TestXgemm<2, clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HGEMM"); if (errors > 0) { return 1; } else { return 0; } } diff --git a/test/correctness/testblas.cpp b/test/correctness/testblas.cpp index 659131c5..aa4b4785 100644 --- a/test/correctness/testblas.cpp +++ b/test/correctness/testblas.cpp @@ -241,36 +241,22 @@ void TestBlas<T,U>::TestInvalid(std::vector<Arguments<U>> &test_vector, const st std::cout << std::flush; } - // Creates the OpenCL buffers. Note: we are not using the C++ version since we explicitly + // Creates the buffers. Note: we are not using the cxpp11.h C++ version since we explicitly // want to be able to create invalid buffers (no error checking here). - auto x1 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.x_size*sizeof(T), nullptr,nullptr); - auto y1 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.y_size*sizeof(T), nullptr,nullptr); - auto a1 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.a_size*sizeof(T), nullptr,nullptr); - auto b1 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.b_size*sizeof(T), nullptr,nullptr); - auto c1 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.c_size*sizeof(T), nullptr,nullptr); - auto ap1 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.ap_size*sizeof(T), nullptr,nullptr); - auto d1 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.scalar_size*sizeof(T), nullptr,nullptr); - auto x_vec1 = Buffer<T>(x1); - auto y_vec1 = Buffer<T>(y1); - auto a_mat1 = Buffer<T>(a1); - auto b_mat1 = Buffer<T>(b1); - auto c_mat1 = Buffer<T>(c1); - auto ap_mat1 = Buffer<T>(ap1); - auto scalar1 = Buffer<T>(d1); - auto x2 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.x_size*sizeof(T), nullptr,nullptr); - auto y2 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.y_size*sizeof(T), nullptr,nullptr); - auto a2 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.a_size*sizeof(T), nullptr,nullptr); - auto b2 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.b_size*sizeof(T), nullptr,nullptr); - auto c2 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.c_size*sizeof(T), nullptr,nullptr); - auto ap2 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.ap_size*sizeof(T), nullptr,nullptr); - auto d2 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.scalar_size*sizeof(T), nullptr,nullptr); - auto x_vec2 = Buffer<T>(x2); - auto y_vec2 = Buffer<T>(y2); - auto a_mat2 = Buffer<T>(a2); - auto b_mat2 = Buffer<T>(b2); - auto c_mat2 = Buffer<T>(c2); - auto ap_mat2 = Buffer<T>(ap2); - auto scalar2 = Buffer<T>(d2); + auto x_vec1 = CreateInvalidBuffer<T>(context_, args.x_size); + auto y_vec1 = CreateInvalidBuffer<T>(context_, args.y_size); + auto a_mat1 = CreateInvalidBuffer<T>(context_, args.a_size); + auto b_mat1 = CreateInvalidBuffer<T>(context_, args.b_size); + auto c_mat1 = CreateInvalidBuffer<T>(context_, args.c_size); + auto ap_mat1 = CreateInvalidBuffer<T>(context_, args.ap_size); + auto scalar1 = CreateInvalidBuffer<T>(context_, args.scalar_size); + auto x_vec2 = CreateInvalidBuffer<T>(context_, args.x_size); + auto y_vec2 = CreateInvalidBuffer<T>(context_, args.y_size); + auto a_mat2 = CreateInvalidBuffer<T>(context_, args.a_size); + auto b_mat2 = CreateInvalidBuffer<T>(context_, args.b_size); + auto c_mat2 = CreateInvalidBuffer<T>(context_, args.c_size); + auto ap_mat2 = CreateInvalidBuffer<T>(context_, args.ap_size); + auto scalar2 = CreateInvalidBuffer<T>(context_, args.scalar_size); auto buffers1 = Buffers<T>{x_vec1, y_vec1, a_mat1, b_mat1, c_mat1, ap_mat1, scalar1}; auto buffers2 = Buffers<T>{x_vec2, y_vec2, a_mat2, b_mat2, c_mat2, ap_mat2, scalar2}; diff --git a/test/correctness/tester.hpp b/test/correctness/tester.hpp index caf03787..640f870a 100644 --- a/test/correctness/tester.hpp +++ b/test/correctness/tester.hpp @@ -22,13 +22,13 @@ #include <vector> #include <memory> +#include "utilities/utilities.hpp" #include "test/test_utilities.hpp" // The libraries #ifdef CLBLAST_REF_CLBLAS #include <clBLAS.h> #endif -#include "clblast.h" namespace clblast { // ================================================================================================= |