summaryrefslogtreecommitdiff
path: root/test/correctness
diff options
context:
space:
mode:
Diffstat (limited to 'test/correctness')
-rw-r--r--test/correctness/misc/override_parameters.cpp3
-rw-r--r--test/correctness/routines/level3/xgemm.cpp15
-rw-r--r--test/correctness/testblas.cpp44
-rw-r--r--test/correctness/tester.hpp2
4 files changed, 28 insertions, 36 deletions
diff --git a/test/correctness/misc/override_parameters.cpp b/test/correctness/misc/override_parameters.cpp
index 535d9286..05f40f57 100644
--- a/test/correctness/misc/override_parameters.cpp
+++ b/test/correctness/misc/override_parameters.cpp
@@ -28,7 +28,7 @@ size_t RunOverrideTests(int argc, char *argv[], const bool silent, const std::st
auto arguments = RetrieveCommandLineArguments(argc, argv);
auto errors = size_t{0};
auto passed = size_t{0};
- auto example_routine = TestXgemm<T>();
+ auto example_routine = TestXgemm<0, T>();
constexpr auto kSeed = 42; // fixed seed for reproducibility
// Determines the test settings
@@ -37,6 +37,7 @@ size_t RunOverrideTests(int argc, char *argv[], const bool silent, const std::st
const auto valid_settings = std::vector<std::unordered_map<std::string,size_t>>{
{ {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} },
{ {"KWG",32}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",32}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",32}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} },
+ { {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} },
};
const auto invalid_settings = std::vector<std::unordered_map<std::string,size_t>>{
{ {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0} },
diff --git a/test/correctness/routines/level3/xgemm.cpp b/test/correctness/routines/level3/xgemm.cpp
index 5de73554..351e538b 100644
--- a/test/correctness/routines/level3/xgemm.cpp
+++ b/test/correctness/routines/level3/xgemm.cpp
@@ -15,11 +15,16 @@
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
auto errors = size_t{0};
- errors += clblast::RunTests<clblast::TestXgemm<float>, float, float>(argc, argv, false, "SGEMM");
- errors += clblast::RunTests<clblast::TestXgemm<double>, double, double>(argc, argv, true, "DGEMM");
- errors += clblast::RunTests<clblast::TestXgemm<clblast::float2>, clblast::float2, clblast::float2>(argc, argv, true, "CGEMM");
- errors += clblast::RunTests<clblast::TestXgemm<clblast::double2>, clblast::double2, clblast::double2>(argc, argv, true, "ZGEMM");
- errors += clblast::RunTests<clblast::TestXgemm<clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<1, float>, float, float>(argc, argv, false, "SGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<1, double>, double, double>(argc, argv, true, "DGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<1, clblast::float2>, clblast::float2, clblast::float2>(argc, argv, true, "CGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<1, clblast::double2>, clblast::double2, clblast::double2>(argc, argv, true, "ZGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<1, clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<2, float>, float, float>(argc, argv, true, "SGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<2, double>, double, double>(argc, argv, true, "DGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<2, clblast::float2>, clblast::float2, clblast::float2>(argc, argv, true, "CGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<2, clblast::double2>, clblast::double2, clblast::double2>(argc, argv, true, "ZGEMM");
+ errors += clblast::RunTests<clblast::TestXgemm<2, clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HGEMM");
if (errors > 0) { return 1; } else { return 0; }
}
diff --git a/test/correctness/testblas.cpp b/test/correctness/testblas.cpp
index 659131c5..aa4b4785 100644
--- a/test/correctness/testblas.cpp
+++ b/test/correctness/testblas.cpp
@@ -241,36 +241,22 @@ void TestBlas<T,U>::TestInvalid(std::vector<Arguments<U>> &test_vector, const st
std::cout << std::flush;
}
- // Creates the OpenCL buffers. Note: we are not using the C++ version since we explicitly
+ // Creates the buffers. Note: we are not using the cxpp11.h C++ version since we explicitly
// want to be able to create invalid buffers (no error checking here).
- auto x1 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.x_size*sizeof(T), nullptr,nullptr);
- auto y1 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.y_size*sizeof(T), nullptr,nullptr);
- auto a1 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.a_size*sizeof(T), nullptr,nullptr);
- auto b1 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.b_size*sizeof(T), nullptr,nullptr);
- auto c1 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.c_size*sizeof(T), nullptr,nullptr);
- auto ap1 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.ap_size*sizeof(T), nullptr,nullptr);
- auto d1 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.scalar_size*sizeof(T), nullptr,nullptr);
- auto x_vec1 = Buffer<T>(x1);
- auto y_vec1 = Buffer<T>(y1);
- auto a_mat1 = Buffer<T>(a1);
- auto b_mat1 = Buffer<T>(b1);
- auto c_mat1 = Buffer<T>(c1);
- auto ap_mat1 = Buffer<T>(ap1);
- auto scalar1 = Buffer<T>(d1);
- auto x2 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.x_size*sizeof(T), nullptr,nullptr);
- auto y2 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.y_size*sizeof(T), nullptr,nullptr);
- auto a2 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.a_size*sizeof(T), nullptr,nullptr);
- auto b2 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.b_size*sizeof(T), nullptr,nullptr);
- auto c2 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.c_size*sizeof(T), nullptr,nullptr);
- auto ap2 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.ap_size*sizeof(T), nullptr,nullptr);
- auto d2 = clCreateBuffer(context_(), CL_MEM_READ_WRITE, args.scalar_size*sizeof(T), nullptr,nullptr);
- auto x_vec2 = Buffer<T>(x2);
- auto y_vec2 = Buffer<T>(y2);
- auto a_mat2 = Buffer<T>(a2);
- auto b_mat2 = Buffer<T>(b2);
- auto c_mat2 = Buffer<T>(c2);
- auto ap_mat2 = Buffer<T>(ap2);
- auto scalar2 = Buffer<T>(d2);
+ auto x_vec1 = CreateInvalidBuffer<T>(context_, args.x_size);
+ auto y_vec1 = CreateInvalidBuffer<T>(context_, args.y_size);
+ auto a_mat1 = CreateInvalidBuffer<T>(context_, args.a_size);
+ auto b_mat1 = CreateInvalidBuffer<T>(context_, args.b_size);
+ auto c_mat1 = CreateInvalidBuffer<T>(context_, args.c_size);
+ auto ap_mat1 = CreateInvalidBuffer<T>(context_, args.ap_size);
+ auto scalar1 = CreateInvalidBuffer<T>(context_, args.scalar_size);
+ auto x_vec2 = CreateInvalidBuffer<T>(context_, args.x_size);
+ auto y_vec2 = CreateInvalidBuffer<T>(context_, args.y_size);
+ auto a_mat2 = CreateInvalidBuffer<T>(context_, args.a_size);
+ auto b_mat2 = CreateInvalidBuffer<T>(context_, args.b_size);
+ auto c_mat2 = CreateInvalidBuffer<T>(context_, args.c_size);
+ auto ap_mat2 = CreateInvalidBuffer<T>(context_, args.ap_size);
+ auto scalar2 = CreateInvalidBuffer<T>(context_, args.scalar_size);
auto buffers1 = Buffers<T>{x_vec1, y_vec1, a_mat1, b_mat1, c_mat1, ap_mat1, scalar1};
auto buffers2 = Buffers<T>{x_vec2, y_vec2, a_mat2, b_mat2, c_mat2, ap_mat2, scalar2};
diff --git a/test/correctness/tester.hpp b/test/correctness/tester.hpp
index caf03787..640f870a 100644
--- a/test/correctness/tester.hpp
+++ b/test/correctness/tester.hpp
@@ -22,13 +22,13 @@
#include <vector>
#include <memory>
+#include "utilities/utilities.hpp"
#include "test/test_utilities.hpp"
// The libraries
#ifdef CLBLAST_REF_CLBLAS
#include <clBLAS.h>
#endif
-#include "clblast.h"
namespace clblast {
// =================================================================================================