summaryrefslogtreecommitdiff
path: root/src/tuning/xgemm.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/tuning/xgemm.cc')
-rw-r--r--src/tuning/xgemm.cc8
1 files changed, 5 insertions, 3 deletions
diff --git a/src/tuning/xgemm.cc b/src/tuning/xgemm.cc
index 2b4ff456..d309b830 100644
--- a/src/tuning/xgemm.cc
+++ b/src/tuning/xgemm.cc
@@ -121,11 +121,13 @@ class TuneXgemm {
std::vector<T> &, std::vector<T> &,
std::vector<T> &a_mat, std::vector<T> &b_mat, std::vector<T> &c_mat,
std::vector<T> &) {
+ auto alpha_buffer = std::vector<T>{args.alpha};
+ auto beta_buffer = std::vector<T>{args.beta};
tuner.AddArgumentScalar(static_cast<int>(args.m));
tuner.AddArgumentScalar(static_cast<int>(args.n));
tuner.AddArgumentScalar(static_cast<int>(args.k));
- tuner.AddArgumentScalar(args.alpha);
- tuner.AddArgumentScalar(args.beta);
+ tuner.AddArgumentInput(alpha_buffer);
+ tuner.AddArgumentInput(beta_buffer);
tuner.AddArgumentInput(a_mat);
tuner.AddArgumentInput(b_mat);
tuner.AddArgumentOutput(c_mat);
@@ -148,7 +150,7 @@ using double2 = clblast::double2;
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
switch(clblast::GetPrecision(argc, argv)) {
- case clblast::Precision::kHalf: throw std::runtime_error("Unsupported precision mode");
+ case clblast::Precision::kHalf: clblast::Tuner<clblast::TuneXgemm<half>, half>(argc, argv); break;
case clblast::Precision::kSingle: clblast::Tuner<clblast::TuneXgemm<float>, float>(argc, argv); break;
case clblast::Precision::kDouble: clblast::Tuner<clblast::TuneXgemm<double>, double>(argc, argv); break;
case clblast::Precision::kComplexSingle: clblast::Tuner<clblast::TuneXgemm<float2>, float2>(argc, argv); break;