diff options
author | JishinMaster <francois.turban@gmail.com> | 2021-03-07 21:44:20 +0100 |
---|---|---|
committer | JishinMaster <francois.turban@gmail.com> | 2021-03-13 21:48:04 +0100 |
commit | aec45ea63755a7a9414d7c6a81d570200e948806 (patch) | |
tree | c2ec91cfcc73fe29de643afc07a68ede3bd39c36 /test/routines/levelx | |
parent | ce44c3adb57fc8d464a41d3db8103dc3fe0e86c4 (diff) |
set the correct flop count for xgemm
Diffstat (limited to 'test/routines/levelx')
-rw-r--r-- | test/routines/levelx/xgemmbatched.hpp | 8 | ||||
-rw-r--r-- | test/routines/levelx/xgemmstridedbatched.hpp | 8 |
2 files changed, 14 insertions, 2 deletions
diff --git a/test/routines/levelx/xgemmbatched.hpp b/test/routines/levelx/xgemmbatched.hpp index b787ca27..c299cfdb 100644 --- a/test/routines/levelx/xgemmbatched.hpp +++ b/test/routines/levelx/xgemmbatched.hpp @@ -217,7 +217,13 @@ class TestXgemmBatched { // Describes how to compute performance metrics static size_t GetFlops(const Arguments<T> &args) { - return args.batch_count * (2 * args.m * args.n * args.k); + if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) { + // complex flops + return args.batch_count * args.m * args.n * (8 * args.k - 2); + } else { + // scalar flops + return args.batch_count * args.m * args.n * (2 * args.k - 1); + } } static size_t GetBytes(const Arguments<T> &args) { return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T); diff --git a/test/routines/levelx/xgemmstridedbatched.hpp b/test/routines/levelx/xgemmstridedbatched.hpp index ddb32997..0522a569 100644 --- a/test/routines/levelx/xgemmstridedbatched.hpp +++ b/test/routines/levelx/xgemmstridedbatched.hpp @@ -204,7 +204,13 @@ public: // Describes how to compute performance metrics static size_t GetFlops(const Arguments<T> &args) { - return args.batch_count * (2 * args.m * args.n * args.k); + if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) { + // complex flops + return args.batch_count * args.m * args.n * (8 * args.k - 2); + } else { + // scalar flops + return args.batch_count * args.m * args.n * (2 * args.k - 1); + } } static size_t GetBytes(const Arguments<T> &args) { return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T); |