diff options
-rw-r--r-- | include/internal/utilities.h | 2 | ||||
-rw-r--r-- | src/utilities.cc | 8 | ||||
-rw-r--r-- | test/correctness/testblas.h | 1 | ||||
-rw-r--r-- | test/correctness/tester.cc | 1 | ||||
-rw-r--r-- | test/performance/client.cc | 2 |
5 files changed, 14 insertions, 0 deletions
diff --git a/include/internal/utilities.h b/include/internal/utilities.h index 93cd509e..60d70eae 100644 --- a/include/internal/utilities.h +++ b/include/internal/utilities.h @@ -46,6 +46,7 @@ constexpr auto kArgATransp = "transA"; constexpr auto kArgBTransp = "transB"; constexpr auto kArgSide = "side"; constexpr auto kArgTriangle = "triangle"; +constexpr auto kArgDiagonal = "diagonal"; constexpr auto kArgXInc = "incx"; constexpr auto kArgYInc = "incy"; constexpr auto kArgXOffset = "offx"; @@ -93,6 +94,7 @@ struct Arguments { Transpose b_transpose = Transpose::kNo; Side side = Side::kLeft; Triangle triangle = Triangle::kUpper; + Diagonal diagonal = Diagonal::kUnit; size_t x_inc = 1; size_t y_inc = 1; size_t x_offset = 0; diff --git a/src/utilities.cc b/src/utilities.cc index 98570088..62abbb91 100644 --- a/src/utilities.cc +++ b/src/utilities.cc @@ -79,6 +79,13 @@ std::string ToString(Triangle value) { } } template <> +std::string ToString(Diagonal value) { + switch(value) { + case Diagonal::kUnit: return ToString(static_cast<int>(value))+" (unit)"; + case Diagonal::kNonUnit: return ToString(static_cast<int>(value))+" (non-unit)"; + } +} +template <> std::string ToString(Precision value) { switch(value) { case Precision::kHalf: return ToString(static_cast<int>(value))+" (half)"; @@ -143,6 +150,7 @@ template Layout GetArgument<Layout>(const int, char **, std::string&, const std: template Transpose GetArgument<Transpose>(const int, char **, std::string&, const std::string&, const Transpose); template Side GetArgument<Side>(const int, char **, std::string&, const std::string&, const Side); template Triangle GetArgument<Triangle>(const int, char **, std::string&, const std::string&, const Triangle); +template Diagonal GetArgument<Diagonal>(const int, char **, std::string&, const std::string&, const Diagonal); template Precision GetArgument<Precision>(const int, char **, std::string&, const std::string&, const Precision); // ================================================================================================= diff --git a/test/correctness/testblas.h b/test/correctness/testblas.h index 1f92cb30..7469700d 100644 --- a/test/correctness/testblas.h +++ b/test/correctness/testblas.h @@ -61,6 +61,7 @@ class TestBlas: public Tester<T> { const std::vector<Layout> kLayouts = {Layout::kRowMajor, Layout::kColMajor}; const std::vector<Triangle> kTriangles = {Triangle::kUpper, Triangle::kLower}; const std::vector<Side> kSides = {Side::kLeft, Side::kRight}; + const std::vector<Diagonal> kDiagonals = {Diagonal::kUnit, Diagonal::kNonUnit}; static const std::vector<Transpose> kTransposes; // Data-type dependent, see .cc-file // Shorthand for the routine-specific functions passed to the tester diff --git a/test/correctness/tester.cc b/test/correctness/tester.cc index 4a179718..db4ee619 100644 --- a/test/correctness/tester.cc +++ b/test/correctness/tester.cc @@ -137,6 +137,7 @@ void Tester<T>::TestEnd() { if (o == kArgBTransp) { fprintf(stdout, "%s=%d ", kArgBTransp, entry.args.b_transpose);} if (o == kArgSide) { fprintf(stdout, "%s=%d ", kArgSide, entry.args.side);} if (o == kArgTriangle) { fprintf(stdout, "%s=%d ", kArgTriangle, entry.args.triangle);} + if (o == kArgDiagonal) { fprintf(stdout, "%s=%d ", kArgDiagonal, entry.args.diagonal);} if (o == kArgXInc) { fprintf(stdout, "%s=%lu ", kArgXInc, entry.args.x_inc);} if (o == kArgYInc) { fprintf(stdout, "%s=%lu ", kArgYInc, entry.args.y_inc);} if (o == kArgXOffset) { fprintf(stdout, "%s=%lu ", kArgXOffset, entry.args.x_offset);} diff --git a/test/performance/client.cc b/test/performance/client.cc index 71471dde..fad0f3a9 100644 --- a/test/performance/client.cc +++ b/test/performance/client.cc @@ -58,6 +58,7 @@ Arguments<T> Client<T>::ParseArguments(int argc, char *argv[], const GetMetric d if (o == kArgBTransp) { args.b_transpose = GetArgument(argc, argv, help, kArgBTransp, Transpose::kNo); } if (o == kArgSide) { args.side = GetArgument(argc, argv, help, kArgSide, Side::kLeft); } if (o == kArgTriangle) { args.triangle = GetArgument(argc, argv, help, kArgTriangle, Triangle::kUpper); } + if (o == kArgDiagonal) { args.diagonal = GetArgument(argc, argv, help, kArgDiagonal, Diagonal::kUnit); } // Vector arguments if (o == kArgXInc) { args.x_inc = GetArgument(argc, argv, help, kArgXInc, size_t{1}); } @@ -224,6 +225,7 @@ void Client<T>::PrintTableRow(const Arguments<T>& args, const double ms_clblast, else if (o == kArgTriangle) { integers.push_back(static_cast<size_t>(args.triangle)); } else if (o == kArgATransp) { integers.push_back(static_cast<size_t>(args.a_transpose)); } else if (o == kArgBTransp) { integers.push_back(static_cast<size_t>(args.b_transpose)); } + else if (o == kArgDiagonal) { integers.push_back(static_cast<size_t>(args.diagonal)); } else if (o == kArgXInc) { integers.push_back(args.x_inc); } else if (o == kArgYInc) { integers.push_back(args.y_inc); } else if (o == kArgXOffset) { integers.push_back(args.x_offset); } |