summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/internal/utilities.h2
-rw-r--r--src/utilities.cc8
-rw-r--r--test/correctness/testblas.h1
-rw-r--r--test/correctness/tester.cc1
-rw-r--r--test/performance/client.cc2
5 files changed, 14 insertions, 0 deletions
diff --git a/include/internal/utilities.h b/include/internal/utilities.h
index 93cd509e..60d70eae 100644
--- a/include/internal/utilities.h
+++ b/include/internal/utilities.h
@@ -46,6 +46,7 @@ constexpr auto kArgATransp = "transA";
constexpr auto kArgBTransp = "transB";
constexpr auto kArgSide = "side";
constexpr auto kArgTriangle = "triangle";
+constexpr auto kArgDiagonal = "diagonal";
constexpr auto kArgXInc = "incx";
constexpr auto kArgYInc = "incy";
constexpr auto kArgXOffset = "offx";
@@ -93,6 +94,7 @@ struct Arguments {
Transpose b_transpose = Transpose::kNo;
Side side = Side::kLeft;
Triangle triangle = Triangle::kUpper;
+ Diagonal diagonal = Diagonal::kUnit;
size_t x_inc = 1;
size_t y_inc = 1;
size_t x_offset = 0;
diff --git a/src/utilities.cc b/src/utilities.cc
index 98570088..62abbb91 100644
--- a/src/utilities.cc
+++ b/src/utilities.cc
@@ -79,6 +79,13 @@ std::string ToString(Triangle value) {
}
}
template <>
+std::string ToString(Diagonal value) {
+ switch(value) {
+ case Diagonal::kUnit: return ToString(static_cast<int>(value))+" (unit)";
+ case Diagonal::kNonUnit: return ToString(static_cast<int>(value))+" (non-unit)";
+ }
+}
+template <>
std::string ToString(Precision value) {
switch(value) {
case Precision::kHalf: return ToString(static_cast<int>(value))+" (half)";
@@ -143,6 +150,7 @@ template Layout GetArgument<Layout>(const int, char **, std::string&, const std:
template Transpose GetArgument<Transpose>(const int, char **, std::string&, const std::string&, const Transpose);
template Side GetArgument<Side>(const int, char **, std::string&, const std::string&, const Side);
template Triangle GetArgument<Triangle>(const int, char **, std::string&, const std::string&, const Triangle);
+template Diagonal GetArgument<Diagonal>(const int, char **, std::string&, const std::string&, const Diagonal);
template Precision GetArgument<Precision>(const int, char **, std::string&, const std::string&, const Precision);
// =================================================================================================
diff --git a/test/correctness/testblas.h b/test/correctness/testblas.h
index 1f92cb30..7469700d 100644
--- a/test/correctness/testblas.h
+++ b/test/correctness/testblas.h
@@ -61,6 +61,7 @@ class TestBlas: public Tester<T> {
const std::vector<Layout> kLayouts = {Layout::kRowMajor, Layout::kColMajor};
const std::vector<Triangle> kTriangles = {Triangle::kUpper, Triangle::kLower};
const std::vector<Side> kSides = {Side::kLeft, Side::kRight};
+ const std::vector<Diagonal> kDiagonals = {Diagonal::kUnit, Diagonal::kNonUnit};
static const std::vector<Transpose> kTransposes; // Data-type dependent, see .cc-file
// Shorthand for the routine-specific functions passed to the tester
diff --git a/test/correctness/tester.cc b/test/correctness/tester.cc
index 4a179718..db4ee619 100644
--- a/test/correctness/tester.cc
+++ b/test/correctness/tester.cc
@@ -137,6 +137,7 @@ void Tester<T>::TestEnd() {
if (o == kArgBTransp) { fprintf(stdout, "%s=%d ", kArgBTransp, entry.args.b_transpose);}
if (o == kArgSide) { fprintf(stdout, "%s=%d ", kArgSide, entry.args.side);}
if (o == kArgTriangle) { fprintf(stdout, "%s=%d ", kArgTriangle, entry.args.triangle);}
+ if (o == kArgDiagonal) { fprintf(stdout, "%s=%d ", kArgDiagonal, entry.args.diagonal);}
if (o == kArgXInc) { fprintf(stdout, "%s=%lu ", kArgXInc, entry.args.x_inc);}
if (o == kArgYInc) { fprintf(stdout, "%s=%lu ", kArgYInc, entry.args.y_inc);}
if (o == kArgXOffset) { fprintf(stdout, "%s=%lu ", kArgXOffset, entry.args.x_offset);}
diff --git a/test/performance/client.cc b/test/performance/client.cc
index 71471dde..fad0f3a9 100644
--- a/test/performance/client.cc
+++ b/test/performance/client.cc
@@ -58,6 +58,7 @@ Arguments<T> Client<T>::ParseArguments(int argc, char *argv[], const GetMetric d
if (o == kArgBTransp) { args.b_transpose = GetArgument(argc, argv, help, kArgBTransp, Transpose::kNo); }
if (o == kArgSide) { args.side = GetArgument(argc, argv, help, kArgSide, Side::kLeft); }
if (o == kArgTriangle) { args.triangle = GetArgument(argc, argv, help, kArgTriangle, Triangle::kUpper); }
+ if (o == kArgDiagonal) { args.diagonal = GetArgument(argc, argv, help, kArgDiagonal, Diagonal::kUnit); }
// Vector arguments
if (o == kArgXInc) { args.x_inc = GetArgument(argc, argv, help, kArgXInc, size_t{1}); }
@@ -224,6 +225,7 @@ void Client<T>::PrintTableRow(const Arguments<T>& args, const double ms_clblast,
else if (o == kArgTriangle) { integers.push_back(static_cast<size_t>(args.triangle)); }
else if (o == kArgATransp) { integers.push_back(static_cast<size_t>(args.a_transpose)); }
else if (o == kArgBTransp) { integers.push_back(static_cast<size_t>(args.b_transpose)); }
+ else if (o == kArgDiagonal) { integers.push_back(static_cast<size_t>(args.diagonal)); }
else if (o == kArgXInc) { integers.push_back(args.x_inc); }
else if (o == kArgYInc) { integers.push_back(args.y_inc); }
else if (o == kArgXOffset) { integers.push_back(args.x_offset); }