summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJishinMaster <francois.turban@gmail.com>2021-03-07 21:44:20 +0100
committerJishinMaster <francois.turban@gmail.com>2021-03-13 21:48:04 +0100
commitaec45ea63755a7a9414d7c6a81d570200e948806 (patch)
treec2ec91cfcc73fe29de643afc07a68ede3bd39c36
parentce44c3adb57fc8d464a41d3db8103dc3fe0e86c4 (diff)
set the correct flop count for xgemm
-rw-r--r--src/tuning/kernels/xgemm.hpp8
-rw-r--r--test/routines/level3/xgemm.hpp8
-rw-r--r--test/routines/level3/xsymm.hpp9
-rw-r--r--test/routines/level3/xsyrk.hpp8
-rw-r--r--test/routines/level3/xtrmm.hpp8
-rw-r--r--test/routines/level3/xtrsm.hpp8
-rw-r--r--test/routines/levelx/xgemmbatched.hpp8
-rw-r--r--test/routines/levelx/xgemmstridedbatched.hpp8
8 files changed, 57 insertions, 8 deletions
diff --git a/src/tuning/kernels/xgemm.hpp b/src/tuning/kernels/xgemm.hpp
index fa1bb6ec..a2ef50c3 100644
--- a/src/tuning/kernels/xgemm.hpp
+++ b/src/tuning/kernels/xgemm.hpp
@@ -159,7 +159,13 @@ TunerSettings XgemmGetTunerSettings(const int V, const Arguments<T> &args) {
}
// Describes how to compute the performance metrics
- settings.metric_amount = 2 * args.m * args.n * args.k;
+ if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
+ // complex flops
+ settings.metric_amount = args.m * args.n * (8 * args.k - 2);
+ } else {
+ // scalar flops
+ settings.metric_amount = args.m * args.n * (2 * args.k - 1);
+ }
settings.performance_unit = "GFLOPS";
return settings;
diff --git a/test/routines/level3/xgemm.hpp b/test/routines/level3/xgemm.hpp
index 4cfa9c83..7a473a4f 100644
--- a/test/routines/level3/xgemm.hpp
+++ b/test/routines/level3/xgemm.hpp
@@ -193,7 +193,13 @@ class TestXgemm {
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
- return 2 * args.m * args.n * args.k;
+ if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
+ // complex flops
+ return args.m * args.n * (8 * args.k - 2);
+ } else {
+ // scalar flops
+ return args.m * args.n * (2 * args.k - 1);
+ }
}
static size_t GetBytes(const Arguments<T> &args) {
return (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);
diff --git a/test/routines/level3/xsymm.hpp b/test/routines/level3/xsymm.hpp
index 837e45d8..0c26c5c0 100644
--- a/test/routines/level3/xsymm.hpp
+++ b/test/routines/level3/xsymm.hpp
@@ -169,7 +169,14 @@ class TestXsymm {
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
- return 2 * args.m * args.n * args.m;
+ if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
+ // complex flops
+ return 8 * args.m * args.n * args.m;
+ } else {
+ // scalar flops
+ return 2 * args.m * args.n * args.m;
+ }
+
}
static size_t GetBytes(const Arguments<T> &args) {
return (args.m*args.m + args.m*args.n + 2*args.m*args.n) * sizeof(T);
diff --git a/test/routines/level3/xsyrk.hpp b/test/routines/level3/xsyrk.hpp
index 23dcf12f..81180740 100644
--- a/test/routines/level3/xsyrk.hpp
+++ b/test/routines/level3/xsyrk.hpp
@@ -153,7 +153,13 @@ class TestXsyrk {
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
- return args.n * args.n * args.k;
+ if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
+ // complex flops
+ return 4 * args.n * args.n * args.k;
+ } else {
+ // scalar flops
+ return args.n * args.n * args.k;
+ }
}
static size_t GetBytes(const Arguments<T> &args) {
return (args.n*args.k + args.n*args.n) * sizeof(T);
diff --git a/test/routines/level3/xtrmm.hpp b/test/routines/level3/xtrmm.hpp
index 51377a16..f39db5cd 100644
--- a/test/routines/level3/xtrmm.hpp
+++ b/test/routines/level3/xtrmm.hpp
@@ -162,7 +162,13 @@ class TestXtrmm {
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
auto k = (args.side == Side::kLeft) ? args.m : args.n;
- return args.m * args.n * k;
+ if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
+ // complex flops
+ return 4 * args.m * args.n * k;
+ } else {
+ // scalar flops
+ return args.m * args.n * k;
+ }
}
static size_t GetBytes(const Arguments<T> &args) {
auto k = (args.side == Side::kLeft) ? args.m : args.n;
diff --git a/test/routines/level3/xtrsm.hpp b/test/routines/level3/xtrsm.hpp
index 66c8f415..9560b116 100644
--- a/test/routines/level3/xtrsm.hpp
+++ b/test/routines/level3/xtrsm.hpp
@@ -173,7 +173,13 @@ class TestXtrsm {
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
auto k = (args.side == Side::kLeft) ? args.m : args.n;
- return args.m * args.n * k;
+ if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
+ // complex flops
+ return 4 * args.m * args.n * k;
+ } else {
+ // scalar flops
+ return args.m * args.n * k;
+ }
}
static size_t GetBytes(const Arguments<T> &args) {
auto k = (args.side == Side::kLeft) ? args.m : args.n;
diff --git a/test/routines/levelx/xgemmbatched.hpp b/test/routines/levelx/xgemmbatched.hpp
index b787ca27..c299cfdb 100644
--- a/test/routines/levelx/xgemmbatched.hpp
+++ b/test/routines/levelx/xgemmbatched.hpp
@@ -217,7 +217,13 @@ class TestXgemmBatched {
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
- return args.batch_count * (2 * args.m * args.n * args.k);
+ if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
+ // complex flops
+ return args.batch_count * args.m * args.n * (8 * args.k - 2);
+ } else {
+ // scalar flops
+ return args.batch_count * args.m * args.n * (2 * args.k - 1);
+ }
}
static size_t GetBytes(const Arguments<T> &args) {
return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);
diff --git a/test/routines/levelx/xgemmstridedbatched.hpp b/test/routines/levelx/xgemmstridedbatched.hpp
index ddb32997..0522a569 100644
--- a/test/routines/levelx/xgemmstridedbatched.hpp
+++ b/test/routines/levelx/xgemmstridedbatched.hpp
@@ -204,7 +204,13 @@ public:
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
- return args.batch_count * (2 * args.m * args.n * args.k);
+ if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
+ // complex flops
+ return args.batch_count * args.m * args.n * (8 * args.k - 2);
+ } else {
+ // scalar flops
+ return args.batch_count * args.m * args.n * (2 * args.k - 1);
+ }
}
static size_t GetBytes(const Arguments<T> &args) {
return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);