summaryrefslogtreecommitdiff
path: root/test/routines
diff options
context:
space:
mode:
Diffstat (limited to 'test/routines')
-rw-r--r--test/routines/common.hpp1
-rw-r--r--test/routines/levelx/xgemmbatched.hpp9
2 files changed, 10 insertions, 0 deletions
diff --git a/test/routines/common.hpp b/test/routines/common.hpp
index 9708288a..47c8f8d7 100644
--- a/test/routines/common.hpp
+++ b/test/routines/common.hpp
@@ -18,6 +18,7 @@
#include <string>
#include "utilities/utilities.hpp"
+#include "test/test_utilities.hpp"
#ifdef CLBLAST_REF_CLBLAS
#include "test/wrapper_clblas.hpp"
diff --git a/test/routines/levelx/xgemmbatched.hpp b/test/routines/levelx/xgemmbatched.hpp
index 56823e47..704d0578 100644
--- a/test/routines/levelx/xgemmbatched.hpp
+++ b/test/routines/levelx/xgemmbatched.hpp
@@ -110,6 +110,15 @@ class TestXgemmBatched {
static StatusCode RunRoutine(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
auto queue_plain = queue();
auto event = cl_event{};
+ // Relaxed requirement on ld_a and ld_b within the library, this is here to match clBLAS
+ auto a_rotated = (args.layout == Layout::kColMajor && args.a_transpose != Transpose::kNo) ||
+ (args.layout == Layout::kRowMajor && args.a_transpose == Transpose::kNo);
+ auto b_rotated = (args.layout == Layout::kColMajor && args.b_transpose != Transpose::kNo) ||
+ (args.layout == Layout::kRowMajor && args.b_transpose == Transpose::kNo);
+ auto a_one = (!a_rotated) ? args.m : args.k;
+ auto b_one = (!b_rotated) ? args.k : args.n;
+ if (args.a_ld < a_one) { return StatusCode::kInvalidLeadDimA; }
+ if (args.b_ld < b_one) { return StatusCode::kInvalidLeadDimB; }
auto status = GemmBatched(args.layout, args.a_transpose, args.b_transpose,
args.m, args.n, args.k, args.alphas.data(),
buffers.a_mat(), args.a_offsets.data(), args.a_ld,