From f77b48692b5dbc9f13f5f93a8242ea546a39236e Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Wed, 12 Jul 2017 21:53:39 +0200 Subject: Relaxed requirement on a_ld and b_ld for batched GEMM --- test/routines/levelx/xgemmbatched.hpp | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'test/routines/levelx') diff --git a/test/routines/levelx/xgemmbatched.hpp b/test/routines/levelx/xgemmbatched.hpp index 56823e47..704d0578 100644 --- a/test/routines/levelx/xgemmbatched.hpp +++ b/test/routines/levelx/xgemmbatched.hpp @@ -110,6 +110,15 @@ class TestXgemmBatched { static StatusCode RunRoutine(const Arguments &args, Buffers &buffers, Queue &queue) { auto queue_plain = queue(); auto event = cl_event{}; + // Relaxed requirement on ld_a and ld_b within the library, this is here to match clBLAS + auto a_rotated = (args.layout == Layout::kColMajor && args.a_transpose != Transpose::kNo) || + (args.layout == Layout::kRowMajor && args.a_transpose == Transpose::kNo); + auto b_rotated = (args.layout == Layout::kColMajor && args.b_transpose != Transpose::kNo) || + (args.layout == Layout::kRowMajor && args.b_transpose == Transpose::kNo); + auto a_one = (!a_rotated) ? args.m : args.k; + auto b_one = (!b_rotated) ? args.k : args.n; + if (args.a_ld < a_one) { return StatusCode::kInvalidLeadDimA; } + if (args.b_ld < b_one) { return StatusCode::kInvalidLeadDimB; } auto status = GemmBatched(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k, args.alphas.data(), buffers.a_mat(), args.a_offsets.data(), args.a_ld, -- cgit v1.2.3