summaryrefslogtreecommitdiff
path: root/test/routines/levelx
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-07-12 21:53:39 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2017-07-12 21:53:39 +0200
commitf77b48692b5dbc9f13f5f93a8242ea546a39236e (patch)
tree6642aa51003134c550abbd37dad57c7c29c091f6 /test/routines/levelx
parentf2477f663672fd37301d6e2ce4646519f71d5cce (diff)
Relaxed requirement on a_ld and b_ld for batched GEMM
Diffstat (limited to 'test/routines/levelx')
-rw-r--r--test/routines/levelx/xgemmbatched.hpp9
1 files changed, 9 insertions, 0 deletions
diff --git a/test/routines/levelx/xgemmbatched.hpp b/test/routines/levelx/xgemmbatched.hpp
index 56823e47..704d0578 100644
--- a/test/routines/levelx/xgemmbatched.hpp
+++ b/test/routines/levelx/xgemmbatched.hpp
@@ -110,6 +110,15 @@ class TestXgemmBatched {
static StatusCode RunRoutine(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
auto queue_plain = queue();
auto event = cl_event{};
+ // Relaxed requirement on ld_a and ld_b within the library, this is here to match clBLAS
+ auto a_rotated = (args.layout == Layout::kColMajor && args.a_transpose != Transpose::kNo) ||
+ (args.layout == Layout::kRowMajor && args.a_transpose == Transpose::kNo);
+ auto b_rotated = (args.layout == Layout::kColMajor && args.b_transpose != Transpose::kNo) ||
+ (args.layout == Layout::kRowMajor && args.b_transpose == Transpose::kNo);
+ auto a_one = (!a_rotated) ? args.m : args.k;
+ auto b_one = (!b_rotated) ? args.k : args.n;
+ if (args.a_ld < a_one) { return StatusCode::kInvalidLeadDimA; }
+ if (args.b_ld < b_one) { return StatusCode::kInvalidLeadDimB; }
auto status = GemmBatched(args.layout, args.a_transpose, args.b_transpose,
args.m, args.n, args.k, args.alphas.data(),
buffers.a_mat(), args.a_offsets.data(), args.a_ld,