diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/routines/level3/xgemm.hpp | 8 | ||||
-rw-r--r-- | test/routines/level3/xsymm.hpp | 9 | ||||
-rw-r--r-- | test/routines/level3/xsyrk.hpp | 8 | ||||
-rw-r--r-- | test/routines/level3/xtrmm.hpp | 8 | ||||
-rw-r--r-- | test/routines/level3/xtrsm.hpp | 8 | ||||
-rw-r--r-- | test/routines/levelx/xgemmbatched.hpp | 8 | ||||
-rw-r--r-- | test/routines/levelx/xgemmstridedbatched.hpp | 8 |
7 files changed, 50 insertions, 7 deletions
diff --git a/test/routines/level3/xgemm.hpp b/test/routines/level3/xgemm.hpp index 4cfa9c83..7a473a4f 100644 --- a/test/routines/level3/xgemm.hpp +++ b/test/routines/level3/xgemm.hpp @@ -193,7 +193,13 @@ class TestXgemm { // Describes how to compute performance metrics static size_t GetFlops(const Arguments<T> &args) { - return 2 * args.m * args.n * args.k; + if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) { + // complex flops + return args.m * args.n * (8 * args.k - 2); + } else { + // scalar flops + return args.m * args.n * (2 * args.k - 1); + } } static size_t GetBytes(const Arguments<T> &args) { return (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T); diff --git a/test/routines/level3/xsymm.hpp b/test/routines/level3/xsymm.hpp index 837e45d8..0c26c5c0 100644 --- a/test/routines/level3/xsymm.hpp +++ b/test/routines/level3/xsymm.hpp @@ -169,7 +169,14 @@ class TestXsymm { // Describes how to compute performance metrics static size_t GetFlops(const Arguments<T> &args) { - return 2 * args.m * args.n * args.m; + if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) { + // complex flops + return 8 * args.m * args.n * args.m; + } else { + // scalar flops + return 2 * args.m * args.n * args.m; + } + } static size_t GetBytes(const Arguments<T> &args) { return (args.m*args.m + args.m*args.n + 2*args.m*args.n) * sizeof(T); diff --git a/test/routines/level3/xsyrk.hpp b/test/routines/level3/xsyrk.hpp index 23dcf12f..81180740 100644 --- a/test/routines/level3/xsyrk.hpp +++ b/test/routines/level3/xsyrk.hpp @@ -153,7 +153,13 @@ class TestXsyrk { // Describes how to compute performance metrics static size_t GetFlops(const Arguments<T> &args) { - return args.n * args.n * args.k; + if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) { + // complex flops + return 4 * args.n * args.n * args.k; + } else { + // scalar flops + return args.n * args.n * args.k; + } } static size_t GetBytes(const Arguments<T> &args) { return (args.n*args.k + args.n*args.n) * sizeof(T); diff --git a/test/routines/level3/xtrmm.hpp b/test/routines/level3/xtrmm.hpp index 51377a16..f39db5cd 100644 --- a/test/routines/level3/xtrmm.hpp +++ b/test/routines/level3/xtrmm.hpp @@ -162,7 +162,13 @@ class TestXtrmm { // Describes how to compute performance metrics static size_t GetFlops(const Arguments<T> &args) { auto k = (args.side == Side::kLeft) ? args.m : args.n; - return args.m * args.n * k; + if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) { + // complex flops + return 4 * args.m * args.n * k; + } else { + // scalar flops + return args.m * args.n * k; + } } static size_t GetBytes(const Arguments<T> &args) { auto k = (args.side == Side::kLeft) ? args.m : args.n; diff --git a/test/routines/level3/xtrsm.hpp b/test/routines/level3/xtrsm.hpp index 66c8f415..9560b116 100644 --- a/test/routines/level3/xtrsm.hpp +++ b/test/routines/level3/xtrsm.hpp @@ -173,7 +173,13 @@ class TestXtrsm { // Describes how to compute performance metrics static size_t GetFlops(const Arguments<T> &args) { auto k = (args.side == Side::kLeft) ? args.m : args.n; - return args.m * args.n * k; + if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) { + // complex flops + return 4 * args.m * args.n * k; + } else { + // scalar flops + return args.m * args.n * k; + } } static size_t GetBytes(const Arguments<T> &args) { auto k = (args.side == Side::kLeft) ? args.m : args.n; diff --git a/test/routines/levelx/xgemmbatched.hpp b/test/routines/levelx/xgemmbatched.hpp index b787ca27..c299cfdb 100644 --- a/test/routines/levelx/xgemmbatched.hpp +++ b/test/routines/levelx/xgemmbatched.hpp @@ -217,7 +217,13 @@ class TestXgemmBatched { // Describes how to compute performance metrics static size_t GetFlops(const Arguments<T> &args) { - return args.batch_count * (2 * args.m * args.n * args.k); + if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) { + // complex flops + return args.batch_count * args.m * args.n * (8 * args.k - 2); + } else { + // scalar flops + return args.batch_count * args.m * args.n * (2 * args.k - 1); + } } static size_t GetBytes(const Arguments<T> &args) { return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T); diff --git a/test/routines/levelx/xgemmstridedbatched.hpp b/test/routines/levelx/xgemmstridedbatched.hpp index ddb32997..0522a569 100644 --- a/test/routines/levelx/xgemmstridedbatched.hpp +++ b/test/routines/levelx/xgemmstridedbatched.hpp @@ -204,7 +204,13 @@ public: // Describes how to compute performance metrics static size_t GetFlops(const Arguments<T> &args) { - return args.batch_count * (2 * args.m * args.n * args.k); + if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) { + // complex flops + return args.batch_count * args.m * args.n * (8 * args.k - 2); + } else { + // scalar flops + return args.batch_count * args.m * args.n * (2 * args.k - 1); + } } static size_t GetBytes(const Arguments<T> &args) { return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T); |