diff options
Diffstat (limited to 'src/tuning')
-rw-r--r-- | src/tuning/copy.cc | 2 | ||||
-rw-r--r-- | src/tuning/pad.cc | 24 | ||||
-rw-r--r-- | src/tuning/padtranspose.cc | 2 | ||||
-rw-r--r-- | src/tuning/transpose.cc | 2 | ||||
-rw-r--r-- | src/tuning/xgemm.cc | 8 |
5 files changed, 20 insertions, 18 deletions
diff --git a/src/tuning/copy.cc b/src/tuning/copy.cc index e2837e60..09cdecf1 100644 --- a/src/tuning/copy.cc +++ b/src/tuning/copy.cc @@ -107,7 +107,7 @@ using double2 = clblast::double2; // Main function (not within the clblast namespace) int main(int argc, char *argv[]) { switch(clblast::GetPrecision(argc, argv)) { - case clblast::Precision::kHalf: throw std::runtime_error("Unsupported precision mode"); + case clblast::Precision::kHalf: clblast::Tuner<clblast::TuneCopy<half>, half>(argc, argv); break; case clblast::Precision::kSingle: clblast::Tuner<clblast::TuneCopy<float>, float>(argc, argv); break; case clblast::Precision::kDouble: clblast::Tuner<clblast::TuneCopy<double>, double>(argc, argv); break; case clblast::Precision::kComplexSingle: clblast::Tuner<clblast::TuneCopy<float2>, float2>(argc, argv); break; diff --git a/src/tuning/pad.cc b/src/tuning/pad.cc index 72729422..075688db 100644 --- a/src/tuning/pad.cc +++ b/src/tuning/pad.cc @@ -85,17 +85,17 @@ class TunePad { std::vector<T> &, std::vector<T> &, std::vector<T> &a_mat, std::vector<T> &b_mat, std::vector<T> &, std::vector<T> &) { - tuner.AddArgumentScalar(static_cast<int>(args.m)); - tuner.AddArgumentScalar(static_cast<int>(args.n)); - tuner.AddArgumentScalar(static_cast<int>(args.m)); - tuner.AddArgumentScalar(0); - tuner.AddArgumentInput(a_mat); - tuner.AddArgumentScalar(static_cast<int>(args.m)); - tuner.AddArgumentScalar(static_cast<int>(args.n)); - tuner.AddArgumentScalar(static_cast<int>(args.m)); - tuner.AddArgumentScalar(0); - tuner.AddArgumentOutput(b_mat); - tuner.AddArgumentScalar(0); + tuner.AddArgumentScalar(static_cast<int>(args.m)); + tuner.AddArgumentScalar(static_cast<int>(args.n)); + tuner.AddArgumentScalar(static_cast<int>(args.m)); + tuner.AddArgumentScalar(0); + tuner.AddArgumentInput(a_mat); + tuner.AddArgumentScalar(static_cast<int>(args.m)); + tuner.AddArgumentScalar(static_cast<int>(args.n)); + tuner.AddArgumentScalar(static_cast<int>(args.m)); + tuner.AddArgumentScalar(0); + tuner.AddArgumentOutput(b_mat); + tuner.AddArgumentScalar(0); } // Describes how to compute the performance metrics @@ -115,7 +115,7 @@ using double2 = clblast::double2; // Main function (not within the clblast namespace) int main(int argc, char *argv[]) { switch(clblast::GetPrecision(argc, argv)) { - case clblast::Precision::kHalf: throw std::runtime_error("Unsupported precision mode"); + case clblast::Precision::kHalf: clblast::Tuner<clblast::TunePad<half>, half>(argc, argv); break; case clblast::Precision::kSingle: clblast::Tuner<clblast::TunePad<float>, float>(argc, argv); break; case clblast::Precision::kDouble: clblast::Tuner<clblast::TunePad<double>, double>(argc, argv); break; case clblast::Precision::kComplexSingle: clblast::Tuner<clblast::TunePad<float2>, float2>(argc, argv); break; diff --git a/src/tuning/padtranspose.cc b/src/tuning/padtranspose.cc index 5edd89e0..a970f982 100644 --- a/src/tuning/padtranspose.cc +++ b/src/tuning/padtranspose.cc @@ -119,7 +119,7 @@ using double2 = clblast::double2; // Main function (not within the clblast namespace) int main(int argc, char *argv[]) { switch(clblast::GetPrecision(argc, argv)) { - case clblast::Precision::kHalf: throw std::runtime_error("Unsupported precision mode"); + case clblast::Precision::kHalf: clblast::Tuner<clblast::TunePadTranspose<half>, half>(argc, argv); break; case clblast::Precision::kSingle: clblast::Tuner<clblast::TunePadTranspose<float>, float>(argc, argv); break; case clblast::Precision::kDouble: clblast::Tuner<clblast::TunePadTranspose<double>, double>(argc, argv); break; case clblast::Precision::kComplexSingle: clblast::Tuner<clblast::TunePadTranspose<float2>, float2>(argc, argv); break; diff --git a/src/tuning/transpose.cc b/src/tuning/transpose.cc index 113e0a81..d217a3df 100644 --- a/src/tuning/transpose.cc +++ b/src/tuning/transpose.cc @@ -112,7 +112,7 @@ using double2 = clblast::double2; // Main function (not within the clblast namespace) int main(int argc, char *argv[]) { switch(clblast::GetPrecision(argc, argv)) { - case clblast::Precision::kHalf: throw std::runtime_error("Unsupported precision mode"); + case clblast::Precision::kHalf: clblast::Tuner<clblast::TuneTranspose<half>, half>(argc, argv); break; case clblast::Precision::kSingle: clblast::Tuner<clblast::TuneTranspose<float>, float>(argc, argv); break; case clblast::Precision::kDouble: clblast::Tuner<clblast::TuneTranspose<double>, double>(argc, argv); break; case clblast::Precision::kComplexSingle: clblast::Tuner<clblast::TuneTranspose<float2>, float2>(argc, argv); break; diff --git a/src/tuning/xgemm.cc b/src/tuning/xgemm.cc index 2b4ff456..d309b830 100644 --- a/src/tuning/xgemm.cc +++ b/src/tuning/xgemm.cc @@ -121,11 +121,13 @@ class TuneXgemm { std::vector<T> &, std::vector<T> &, std::vector<T> &a_mat, std::vector<T> &b_mat, std::vector<T> &c_mat, std::vector<T> &) { + auto alpha_buffer = std::vector<T>{args.alpha}; + auto beta_buffer = std::vector<T>{args.beta}; tuner.AddArgumentScalar(static_cast<int>(args.m)); tuner.AddArgumentScalar(static_cast<int>(args.n)); tuner.AddArgumentScalar(static_cast<int>(args.k)); - tuner.AddArgumentScalar(args.alpha); - tuner.AddArgumentScalar(args.beta); + tuner.AddArgumentInput(alpha_buffer); + tuner.AddArgumentInput(beta_buffer); tuner.AddArgumentInput(a_mat); tuner.AddArgumentInput(b_mat); tuner.AddArgumentOutput(c_mat); @@ -148,7 +150,7 @@ using double2 = clblast::double2; // Main function (not within the clblast namespace) int main(int argc, char *argv[]) { switch(clblast::GetPrecision(argc, argv)) { - case clblast::Precision::kHalf: throw std::runtime_error("Unsupported precision mode"); + case clblast::Precision::kHalf: clblast::Tuner<clblast::TuneXgemm<half>, half>(argc, argv); break; case clblast::Precision::kSingle: clblast::Tuner<clblast::TuneXgemm<float>, float>(argc, argv); break; case clblast::Precision::kDouble: clblast::Tuner<clblast::TuneXgemm<double>, double>(argc, argv); break; case clblast::Precision::kComplexSingle: clblast::Tuner<clblast::TuneXgemm<float2>, float2>(argc, argv); break; |