summaryrefslogtreecommitdiff
path: root/test/routines/level3/xher2k.h
diff options
context:
space:
mode:
Diffstat (limited to 'test/routines/level3/xher2k.h')
-rw-r--r--test/routines/level3/xher2k.h12
1 files changed, 10 insertions, 2 deletions
diff --git a/test/routines/level3/xher2k.h b/test/routines/level3/xher2k.h
index b20ec973..a7fbfcbe 100644
--- a/test/routines/level3/xher2k.h
+++ b/test/routines/level3/xher2k.h
@@ -29,6 +29,9 @@ template <typename T, typename U>
class TestXher2k {
public:
+ // The BLAS level: 1, 2, or 3
+ static size_t BLASLevel() { return 3; }
+
// The list of arguments relevant for this routine
static std::vector<std::string> GetOptions() {
return {kArgN, kArgK,
@@ -46,8 +49,8 @@ class TestXher2k {
return a_two * args.a_ld + args.a_offset;
}
static size_t GetSizeB(const Arguments<U> &args) {
- auto b_rotated = (args.layout == Layout::kColMajor && args.b_transpose != Transpose::kNo) ||
- (args.layout == Layout::kRowMajor && args.b_transpose == Transpose::kNo);
+ auto b_rotated = (args.layout == Layout::kColMajor && args.a_transpose != Transpose::kNo) ||
+ (args.layout == Layout::kRowMajor && args.a_transpose == Transpose::kNo);
auto b_two = (b_rotated) ? args.n : args.k;
return b_two * args.b_ld + args.b_offset;
}
@@ -67,6 +70,11 @@ class TestXher2k {
static size_t DefaultLDB(const Arguments<U> &args) { return args.k; }
static size_t DefaultLDC(const Arguments<U> &args) { return args.n; }
+ // Describes which transpose options are relevant for this routine
+ using Transposes = std::vector<Transpose>;
+ static Transposes GetATransposes(const Transposes &) { return {Transpose::kNo, Transpose::kConjugate}; }
+ static Transposes GetBTransposes(const Transposes &) { return {}; } // N/A for this routine
+
// Describes how to run the CLBlast routine
static StatusCode RunRoutine(const Arguments<U> &args, const Buffers<T> &buffers, Queue &queue) {
auto queue_plain = queue();