summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-03-22 22:42:33 +0100
committerGitHub <noreply@github.com>2018-03-22 22:42:33 +0100
commita97d8a01970c49f2b21d952e841668da3db0184d (patch)
tree143bd103186eb0d3fd9672ab8a41ebb076d1107d
parent52791bf3553bb47a50dea4ac234f7e1b09c4383c (diff)
parent9fb6550dd02c54fafbb03e20516a394d9cd63f3f (diff)
Merge pull request #269 from CNugteren/CLBlast-266-local-mem-constraint
CLBlast #266 local mem constraint
-rw-r--r--CHANGELOG1
-rw-r--r--src/tuning/configurations.cpp34
-rw-r--r--src/tuning/configurations.hpp21
-rw-r--r--src/tuning/kernels/copy_fast.cpp10
-rw-r--r--src/tuning/kernels/copy_fast.hpp4
-rw-r--r--src/tuning/kernels/copy_pad.cpp10
-rw-r--r--src/tuning/kernels/copy_pad.hpp4
-rw-r--r--src/tuning/kernels/invert.cpp10
-rw-r--r--src/tuning/kernels/invert.hpp9
-rw-r--r--src/tuning/kernels/transpose_fast.cpp10
-rw-r--r--src/tuning/kernels/transpose_fast.hpp9
-rw-r--r--src/tuning/kernels/transpose_pad.cpp10
-rw-r--r--src/tuning/kernels/transpose_pad.hpp9
-rw-r--r--src/tuning/kernels/xaxpy.cpp10
-rw-r--r--src/tuning/kernels/xaxpy.hpp4
-rw-r--r--src/tuning/kernels/xdot.cpp10
-rw-r--r--src/tuning/kernels/xdot.hpp8
-rw-r--r--src/tuning/kernels/xgemm.cpp10
-rw-r--r--src/tuning/kernels/xgemm.hpp9
-rw-r--r--src/tuning/kernels/xgemm_direct.cpp10
-rw-r--r--src/tuning/kernels/xgemm_direct.hpp9
-rw-r--r--src/tuning/kernels/xgemv.cpp10
-rw-r--r--src/tuning/kernels/xgemv.hpp17
-rw-r--r--src/tuning/kernels/xger.cpp10
-rw-r--r--src/tuning/kernels/xger.hpp4
-rw-r--r--src/tuning/tuning.cpp14
-rw-r--r--src/tuning/tuning.hpp4
-rw-r--r--src/tuning/tuning_api.cpp42
28 files changed, 220 insertions, 92 deletions
diff --git a/CHANGELOG b/CHANGELOG
index 7c22009c..5815e343 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -3,6 +3,7 @@ Development (next version)
- Added Python interface to CLBlast 'PyCLBlast'
- Added CLBlast to Ubuntu PPA and macOS Homebrew package managers
- Added an API to run the tuners programmatically without any I/O
+- Re-added a local memory size constraint to the tuners
- Updated and reorganised the CLBlast documentation
- Various minor fixes and enhancements
- Added non-BLAS level-1 routines:
diff --git a/src/tuning/configurations.cpp b/src/tuning/configurations.cpp
index 459d66b1..1fe232cf 100644
--- a/src/tuning/configurations.cpp
+++ b/src/tuning/configurations.cpp
@@ -21,11 +21,15 @@ namespace clblast {
// =================================================================================================
// Finds all configurations. It also applies the user-defined constraints within.
-std::vector<Configuration> SetConfigurations(const std::vector<Parameter> parameters,
- const Constraints& constraints) {
+std::vector<Configuration> SetConfigurations(const Device& device,
+ const std::vector<Parameter> parameters,
+ const Constraints& constraints,
+ const LocalMemSizeInfo& local_mem_size_info) {
+ const auto local_mem_max = device.LocalMemSize();
auto config = Configuration();
auto configurations = std::vector<Configuration>();
- PopulateConfigurations(parameters, 0, config, configurations, constraints);
+ PopulateConfigurations(parameters, 0, config, configurations,
+ local_mem_max, constraints, local_mem_size_info);
return configurations;
}
@@ -33,12 +37,14 @@ std::vector<Configuration> SetConfigurations(const std::vector<Parameter> parame
void PopulateConfigurations(const std::vector<Parameter> &parameters,
const size_t index, const Configuration &config,
std::vector<Configuration> &configuration,
- const Constraints& constraints) {
+ const size_t local_mem_max,
+ const Constraints& constraints,
+ const LocalMemSizeInfo& local_mem_size_info) {
// End of the chain: all parameters are considered, store the resulting configuration if it is a
// valid one according to the constraints
if (index == parameters.size()) {
- if (ValidConfiguration(config, constraints)) {
+ if (ValidConfiguration(config, local_mem_max, constraints, local_mem_size_info)) {
configuration.push_back(config);
}
return;
@@ -49,13 +55,16 @@ void PopulateConfigurations(const std::vector<Parameter> &parameters,
for (auto &value: parameter.second) {
auto config_copy = config;
config_copy[parameter.first] = value;
- PopulateConfigurations(parameters, index+1, config_copy, configuration, constraints);
+ PopulateConfigurations(parameters, index+1, config_copy, configuration,
+ local_mem_max, constraints, local_mem_size_info);
}
}
// Loops over all user-defined constraints to check whether or not the configuration is valid
bool ValidConfiguration(const Configuration &config,
- const Constraints& constraints) {
+ const size_t local_mem_max,
+ const Constraints& constraints,
+ const LocalMemSizeInfo& local_mem_size_info) {
// Iterates over all constraints
for (auto &constraint: constraints) {
@@ -72,6 +81,17 @@ bool ValidConfiguration(const Configuration &config,
}
}
+ // Finds the values of the local memory parameters
+ auto local_mem_values = std::vector<size_t>(local_mem_size_info.parameters.size());
+ for (auto i=size_t{0}; i<local_mem_size_info.parameters.size(); ++i) {
+ local_mem_values[i] = config.at(local_mem_size_info.parameters[i]);
+ }
+
+ // Checks the local memory size
+ if (local_mem_size_info.local_mem_size(local_mem_values) > local_mem_max) {
+ return false;
+ }
+
// Everything was OK: this configuration is valid
return true;
}
diff --git a/src/tuning/configurations.hpp b/src/tuning/configurations.hpp
index 74679ff6..faa5498f 100644
--- a/src/tuning/configurations.hpp
+++ b/src/tuning/configurations.hpp
@@ -37,12 +37,21 @@ struct Constraint {
};
using Constraints = std::vector<Constraint>;
+// As above, but for local memory size
+using LocalMemSizeFunction = std::function<size_t(std::vector<size_t>)>;
+struct LocalMemSizeInfo {
+ LocalMemSizeFunction local_mem_size;
+ std::vector<std::string> parameters;
+};
+
// =================================================================================================
// Initializes an empty configuration (vector of name/value pairs) and kicks-off the recursive
// function to find all configurations. It also applies the user-defined constraints within.
-std::vector<Configuration> SetConfigurations(const std::vector<Parameter> parameters,
- const Constraints& constraints);
+std::vector<Configuration> SetConfigurations(const Device& device,
+ const std::vector<Parameter> parameters,
+ const Constraints& constraints,
+ const LocalMemSizeInfo& local_mem_size_info);
// Iterates recursively over all permutations of the user-defined parameters. This code creates
// multiple chains, in which each chain selects a unique combination of values for all parameters.
@@ -51,14 +60,18 @@ std::vector<Configuration> SetConfigurations(const std::vector<Parameter> parame
void PopulateConfigurations(const std::vector<Parameter> &parameters,
const size_t index, const Configuration &config,
std::vector<Configuration> &configuration,
- const Constraints& constraints);
+ const size_t local_mem_max,
+ const Constraints& constraints,
+ const LocalMemSizeInfo& local_mem_size_info);
// Loops over all user-defined constraints to check whether or not the configuration is valid.
// Assumes initially all configurations are valid, then returns false if one of the constraints has
// not been met. Constraints consist of a user-defined function and a list of parameter names, which
// are replaced by parameter values in this function.
bool ValidConfiguration(const Configuration &config,
- const Constraints& constraints);
+ const size_t local_mem_max,
+ const Constraints& constraints,
+ const LocalMemSizeInfo& local_mem_size_info);
// Processes multipliers and dividers to obtain the final thread configuration
std::vector<size_t> SetThreadConfiguration(const Configuration& config,
diff --git a/src/tuning/kernels/copy_fast.cpp b/src/tuning/kernels/copy_fast.cpp
index 0314113c..13f7ef3c 100644
--- a/src/tuning/kernels/copy_fast.cpp
+++ b/src/tuning/kernels/copy_fast.cpp
@@ -22,11 +22,11 @@ using double2 = clblast::double2;
int main(int argc, char *argv[]) {
const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
switch(clblast::GetPrecision(command_line_args)) {
- case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, 0, clblast::CopyGetTunerDefaults, clblast::CopyGetTunerSettings<half>, clblast::CopyTestValidArguments<half>, clblast::CopySetConstraints, clblast::CopySetArguments<half>); break;
- case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, 0, clblast::CopyGetTunerDefaults, clblast::CopyGetTunerSettings<float>, clblast::CopyTestValidArguments<float>, clblast::CopySetConstraints, clblast::CopySetArguments<float>); break;
- case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, 0, clblast::CopyGetTunerDefaults, clblast::CopyGetTunerSettings<double>, clblast::CopyTestValidArguments<double>, clblast::CopySetConstraints, clblast::CopySetArguments<double>); break;
- case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, 0, clblast::CopyGetTunerDefaults, clblast::CopyGetTunerSettings<float2>, clblast::CopyTestValidArguments<float2>, clblast::CopySetConstraints, clblast::CopySetArguments<float2>); break;
- case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, 0, clblast::CopyGetTunerDefaults, clblast::CopyGetTunerSettings<double2>, clblast::CopyTestValidArguments<double2>, clblast::CopySetConstraints, clblast::CopySetArguments<double2>); break;
+ case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, 0, clblast::CopyGetTunerDefaults, clblast::CopyGetTunerSettings<half>, clblast::CopyTestValidArguments<half>, clblast::CopySetConstraints, clblast::CopyComputeLocalMemSize<half>, clblast::CopySetArguments<half>); break;
+ case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, 0, clblast::CopyGetTunerDefaults, clblast::CopyGetTunerSettings<float>, clblast::CopyTestValidArguments<float>, clblast::CopySetConstraints, clblast::CopyComputeLocalMemSize<float>, clblast::CopySetArguments<float>); break;
+ case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, 0, clblast::CopyGetTunerDefaults, clblast::CopyGetTunerSettings<double>, clblast::CopyTestValidArguments<double>, clblast::CopySetConstraints, clblast::CopyComputeLocalMemSize<double>, clblast::CopySetArguments<double>); break;
+ case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, 0, clblast::CopyGetTunerDefaults, clblast::CopyGetTunerSettings<float2>, clblast::CopyTestValidArguments<float2>, clblast::CopySetConstraints, clblast::CopyComputeLocalMemSize<float2>, clblast::CopySetArguments<float2>); break;
+ case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, 0, clblast::CopyGetTunerDefaults, clblast::CopyGetTunerSettings<double2>, clblast::CopyTestValidArguments<double2>, clblast::CopySetConstraints, clblast::CopyComputeLocalMemSize<double2>, clblast::CopySetArguments<double2>); break;
}
return 0;
}
diff --git a/src/tuning/kernels/copy_fast.hpp b/src/tuning/kernels/copy_fast.hpp
index f9a58bc7..1c4219ae 100644
--- a/src/tuning/kernels/copy_fast.hpp
+++ b/src/tuning/kernels/copy_fast.hpp
@@ -79,6 +79,10 @@ TunerSettings CopyGetTunerSettings(const int, const Arguments<T> &args) {
template <typename T>
void CopyTestValidArguments(const int, const Arguments<T> &) { }
std::vector<Constraint> CopySetConstraints(const int) { return {}; }
+template <typename T>
+LocalMemSizeInfo CopyComputeLocalMemSize(const int) {
+ return { [] (std::vector<size_t>) -> size_t { return 0; }, {} };
+}
// Sets the kernel's arguments
template <typename T>
diff --git a/src/tuning/kernels/copy_pad.cpp b/src/tuning/kernels/copy_pad.cpp
index 909a71c8..ffaed6ed 100644
--- a/src/tuning/kernels/copy_pad.cpp
+++ b/src/tuning/kernels/copy_pad.cpp
@@ -22,11 +22,11 @@ using double2 = clblast::double2;
int main(int argc, char *argv[]) {
const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
switch(clblast::GetPrecision(command_line_args)) {
- case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, 0, clblast::PadGetTunerDefaults, clblast::PadGetTunerSettings<half>, clblast::PadTestValidArguments<half>, clblast::PadSetConstraints, clblast::PadSetArguments<half>); break;
- case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, 0, clblast::PadGetTunerDefaults, clblast::PadGetTunerSettings<float>, clblast::PadTestValidArguments<float>, clblast::PadSetConstraints, clblast::PadSetArguments<float>); break;
- case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, 0, clblast::PadGetTunerDefaults, clblast::PadGetTunerSettings<double>, clblast::PadTestValidArguments<double>, clblast::PadSetConstraints, clblast::PadSetArguments<double>); break;
- case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, 0, clblast::PadGetTunerDefaults, clblast::PadGetTunerSettings<float2>, clblast::PadTestValidArguments<float2>, clblast::PadSetConstraints, clblast::PadSetArguments<float2>); break;
- case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, 0, clblast::PadGetTunerDefaults, clblast::PadGetTunerSettings<double2>, clblast::PadTestValidArguments<double2>, clblast::PadSetConstraints, clblast::PadSetArguments<double2>); break;
+ case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, 0, clblast::PadGetTunerDefaults, clblast::PadGetTunerSettings<half>, clblast::PadTestValidArguments<half>, clblast::PadSetConstraints, clblast::PadComputeLocalMemSize<half>, clblast::PadSetArguments<half>); break;
+ case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, 0, clblast::PadGetTunerDefaults, clblast::PadGetTunerSettings<float>, clblast::PadTestValidArguments<float>, clblast::PadSetConstraints, clblast::PadComputeLocalMemSize<float>, clblast::PadSetArguments<float>); break;
+ case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, 0, clblast::PadGetTunerDefaults, clblast::PadGetTunerSettings<double>, clblast::PadTestValidArguments<double>, clblast::PadSetConstraints, clblast::PadComputeLocalMemSize<double>, clblast::PadSetArguments<double>); break;
+ case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, 0, clblast::PadGetTunerDefaults, clblast::PadGetTunerSettings<float2>, clblast::PadTestValidArguments<float2>, clblast::PadSetConstraints, clblast::PadComputeLocalMemSize<float2>, clblast::PadSetArguments<float2>); break;
+ case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, 0, clblast::PadGetTunerDefaults, clblast::PadGetTunerSettings<double2>, clblast::PadTestValidArguments<double2>, clblast::PadSetConstraints, clblast::PadComputeLocalMemSize<double2>, clblast::PadSetArguments<double2>); break;
}
return 0;
}
diff --git a/src/tuning/kernels/copy_pad.hpp b/src/tuning/kernels/copy_pad.hpp
index e612ca9e..ada1cf83 100644
--- a/src/tuning/kernels/copy_pad.hpp
+++ b/src/tuning/kernels/copy_pad.hpp
@@ -79,6 +79,10 @@ TunerSettings PadGetTunerSettings(const int, const Arguments<T> &args) {
template <typename T>
void PadTestValidArguments(const int, const Arguments<T> &) { }
std::vector<Constraint> PadSetConstraints(const int) { return {}; }
+template <typename T>
+LocalMemSizeInfo PadComputeLocalMemSize(const int) {
+ return { [] (std::vector<size_t>) -> size_t { return 0; }, {} };
+}
// Sets the kernel's arguments
template <typename T>
diff --git a/src/tuning/kernels/invert.cpp b/src/tuning/kernels/invert.cpp
index 3dfeb508..3795da88 100644
--- a/src/tuning/kernels/invert.cpp
+++ b/src/tuning/kernels/invert.cpp
@@ -22,11 +22,11 @@ using double2 = clblast::double2;
int main(int argc, char *argv[]) {
const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
switch(clblast::GetPrecision(command_line_args)) {
- case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, 0, clblast::InvertGetTunerDefaults, clblast::InvertGetTunerSettings<half>, clblast::InvertTestValidArguments<half>, clblast::InvertSetConstraints, clblast::InvertSetArguments<half>); break;
- case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, 0, clblast::InvertGetTunerDefaults, clblast::InvertGetTunerSettings<float>, clblast::InvertTestValidArguments<float>, clblast::InvertSetConstraints, clblast::InvertSetArguments<float>); break;
- case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, 0, clblast::InvertGetTunerDefaults, clblast::InvertGetTunerSettings<double>, clblast::InvertTestValidArguments<double>, clblast::InvertSetConstraints, clblast::InvertSetArguments<double>); break;
- case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, 0, clblast::InvertGetTunerDefaults, clblast::InvertGetTunerSettings<float2>, clblast::InvertTestValidArguments<float2>, clblast::InvertSetConstraints, clblast::InvertSetArguments<float2>); break;
- case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, 0, clblast::InvertGetTunerDefaults, clblast::InvertGetTunerSettings<double2>, clblast::InvertTestValidArguments<double2>, clblast::InvertSetConstraints, clblast::InvertSetArguments<double2>); break;
+ case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, 0, clblast::InvertGetTunerDefaults, clblast::InvertGetTunerSettings<half>, clblast::InvertTestValidArguments<half>, clblast::InvertSetConstraints, clblast::InvertComputeLocalMemSize<half>, clblast::InvertSetArguments<half>); break;
+ case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, 0, clblast::InvertGetTunerDefaults, clblast::InvertGetTunerSettings<float>, clblast::InvertTestValidArguments<float>, clblast::InvertSetConstraints, clblast::InvertComputeLocalMemSize<float>, clblast::InvertSetArguments<float>); break;
+ case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, 0, clblast::InvertGetTunerDefaults, clblast::InvertGetTunerSettings<double>, clblast::InvertTestValidArguments<double>, clblast::InvertSetConstraints, clblast::InvertComputeLocalMemSize<double>, clblast::InvertSetArguments<double>); break;
+ case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, 0, clblast::InvertGetTunerDefaults, clblast::InvertGetTunerSettings<float2>, clblast::InvertTestValidArguments<float2>, clblast::InvertSetConstraints, clblast::InvertComputeLocalMemSize<float2>, clblast::InvertSetArguments<float2>); break;
+ case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, 0, clblast::InvertGetTunerDefaults, clblast::InvertGetTunerSettings<double2>, clblast::InvertTestValidArguments<double2>, clblast::InvertSetConstraints, clblast::InvertComputeLocalMemSize<double2>, clblast::InvertSetArguments<double2>); break;
}
return 0;
}
diff --git a/src/tuning/kernels/invert.hpp b/src/tuning/kernels/invert.hpp
index 0a0c9ce2..4f74674d 100644
--- a/src/tuning/kernels/invert.hpp
+++ b/src/tuning/kernels/invert.hpp
@@ -87,6 +87,15 @@ void InvertTestValidArguments(const int, const Arguments<T> &args) {
}
}
std::vector<Constraint> InvertSetConstraints(const int) { return {}; }
+template <typename T>
+LocalMemSizeInfo InvertComputeLocalMemSize(const int) {
+ return {
+ [] (std::vector<size_t> v) -> size_t {
+ return GetBytes(PrecisionValue<T>()) * (16 + v[0]) * 16;
+ },
+ {"LOCALPAD"}
+ };
+}
// Sets the kernel's arguments
template <typename T>
diff --git a/src/tuning/kernels/transpose_fast.cpp b/src/tuning/kernels/transpose_fast.cpp
index 6b37a31d..024f7385 100644
--- a/src/tuning/kernels/transpose_fast.cpp
+++ b/src/tuning/kernels/transpose_fast.cpp
@@ -22,11 +22,11 @@ using double2 = clblast::double2;
int main(int argc, char *argv[]) {
const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
switch(clblast::GetPrecision(command_line_args)) {
- case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, 0, clblast::TransposeGetTunerDefaults, clblast::TransposeGetTunerSettings<half>, clblast::TransposeTestValidArguments<half>, clblast::TransposeSetConstraints, clblast::TransposeSetArguments<half>); break;
- case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, 0, clblast::TransposeGetTunerDefaults, clblast::TransposeGetTunerSettings<float>, clblast::TransposeTestValidArguments<float>, clblast::TransposeSetConstraints, clblast::TransposeSetArguments<float>); break;
- case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, 0, clblast::TransposeGetTunerDefaults, clblast::TransposeGetTunerSettings<double>, clblast::TransposeTestValidArguments<double>, clblast::TransposeSetConstraints, clblast::TransposeSetArguments<double>); break;
- case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, 0, clblast::TransposeGetTunerDefaults, clblast::TransposeGetTunerSettings<float2>, clblast::TransposeTestValidArguments<float2>, clblast::TransposeSetConstraints, clblast::TransposeSetArguments<float2>); break;
- case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, 0, clblast::TransposeGetTunerDefaults, clblast::TransposeGetTunerSettings<double2>, clblast::TransposeTestValidArguments<double2>, clblast::TransposeSetConstraints, clblast::TransposeSetArguments<double2>); break;
+ case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, 0, clblast::TransposeGetTunerDefaults, clblast::TransposeGetTunerSettings<half>, clblast::TransposeTestValidArguments<half>, clblast::TransposeSetConstraints, clblast::TransposeComputeLocalMemSize<half>, clblast::TransposeSetArguments<half>); break;
+ case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, 0, clblast::TransposeGetTunerDefaults, clblast::TransposeGetTunerSettings<float>, clblast::TransposeTestValidArguments<float>, clblast::TransposeSetConstraints, clblast::TransposeComputeLocalMemSize<float>, clblast::TransposeSetArguments<float>); break;
+ case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, 0, clblast::TransposeGetTunerDefaults, clblast::TransposeGetTunerSettings<double>, clblast::TransposeTestValidArguments<double>, clblast::TransposeSetConstraints, clblast::TransposeComputeLocalMemSize<double>, clblast::TransposeSetArguments<double>); break;
+ case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, 0, clblast::TransposeGetTunerDefaults, clblast::TransposeGetTunerSettings<float2>, clblast::TransposeTestValidArguments<float2>, clblast::TransposeSetConstraints, clblast::TransposeComputeLocalMemSize<float2>, clblast::TransposeSetArguments<float2>); break;
+ case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, 0, clblast::TransposeGetTunerDefaults, clblast::TransposeGetTunerSettings<double2>, clblast::TransposeTestValidArguments<double2>, clblast::TransposeSetConstraints, clblast::TransposeComputeLocalMemSize<double2>, clblast::TransposeSetArguments<double2>); break;
}
return 0;
}
diff --git a/src/tuning/kernels/transpose_fast.hpp b/src/tuning/kernels/transpose_fast.hpp
index e8917ad2..c6e3f98d 100644
--- a/src/tuning/kernels/transpose_fast.hpp
+++ b/src/tuning/kernels/transpose_fast.hpp
@@ -79,6 +79,15 @@ TunerSettings TransposeGetTunerSettings(const int, const Arguments<T> &args) {
template <typename T>
void TransposeTestValidArguments(const int, const Arguments<T> &) { }
std::vector<Constraint> TransposeSetConstraints(const int) { return {}; }
+template <typename T>
+LocalMemSizeInfo TransposeComputeLocalMemSize(const int) {
+ return {
+ [] (std::vector<size_t> v) -> size_t {
+ return GetBytes(PrecisionValue<T>()) * v[1] * (v[1] * v[0]) * (v[0] + v[2]);
+ },
+ {"TRA_DIM", "TRA_WPT", "TRA_PAD"}
+ };
+}
// Sets the kernel's arguments
template <typename T>
diff --git a/src/tuning/kernels/transpose_pad.cpp b/src/tuning/kernels/transpose_pad.cpp
index fc7244f6..ffaa252b 100644
--- a/src/tuning/kernels/transpose_pad.cpp
+++ b/src/tuning/kernels/transpose_pad.cpp
@@ -22,11 +22,11 @@ using double2 = clblast::double2;
int main(int argc, char *argv[]) {
const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
switch(clblast::GetPrecision(command_line_args)) {
- case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, 0, clblast::PadtransposeGetTunerDefaults, clblast::PadtransposeGetTunerSettings<half>, clblast::PadtransposeTestValidArguments<half>, clblast::PadtransposeSetConstraints, clblast::PadtransposeSetArguments<half>); break;
- case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, 0, clblast::PadtransposeGetTunerDefaults, clblast::PadtransposeGetTunerSettings<float>, clblast::PadtransposeTestValidArguments<float>, clblast::PadtransposeSetConstraints, clblast::PadtransposeSetArguments<float>); break;
- case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, 0, clblast::PadtransposeGetTunerDefaults, clblast::PadtransposeGetTunerSettings<double>, clblast::PadtransposeTestValidArguments<double>, clblast::PadtransposeSetConstraints, clblast::PadtransposeSetArguments<double>); break;
- case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, 0, clblast::PadtransposeGetTunerDefaults, clblast::PadtransposeGetTunerSettings<float2>, clblast::PadtransposeTestValidArguments<float2>, clblast::PadtransposeSetConstraints, clblast::PadtransposeSetArguments<float2>); break;
- case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, 0, clblast::PadtransposeGetTunerDefaults, clblast::PadtransposeGetTunerSettings<double2>, clblast::PadtransposeTestValidArguments<double2>, clblast::PadtransposeSetConstraints, clblast::PadtransposeSetArguments<double2>); break;
+ case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, 0, clblast::PadtransposeGetTunerDefaults, clblast::PadtransposeGetTunerSettings<half>, clblast::PadtransposeTestValidArguments<half>, clblast::PadtransposeSetConstraints, clblast::PadtransposeComputeLocalMemSize<half>, clblast::PadtransposeSetArguments<half>); break;
+ case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, 0, clblast::PadtransposeGetTunerDefaults, clblast::PadtransposeGetTunerSettings<float>, clblast::PadtransposeTestValidArguments<float>, clblast::PadtransposeSetConstraints, clblast::PadtransposeComputeLocalMemSize<float>, clblast::PadtransposeSetArguments<float>); break;
+ case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, 0, clblast::PadtransposeGetTunerDefaults, clblast::PadtransposeGetTunerSettings<double>, clblast::PadtransposeTestValidArguments<double>, clblast::PadtransposeSetConstraints, clblast::PadtransposeComputeLocalMemSize<double>, clblast::PadtransposeSetArguments<double>); break;
+ case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, 0, clblast::PadtransposeGetTunerDefaults, clblast::PadtransposeGetTunerSettings<float2>, clblast::PadtransposeTestValidArguments<float2>, clblast::PadtransposeSetConstraints, clblast::PadtransposeComputeLocalMemSize<float2>, clblast::PadtransposeSetArguments<float2>); break;
+ case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, 0, clblast::PadtransposeGetTunerDefaults, clblast::PadtransposeGetTunerSettings<double2>, clblast::PadtransposeTestValidArguments<double2>, clblast::PadtransposeSetConstraints, clblast::PadtransposeComputeLocalMemSize<double2>, clblast::PadtransposeSetArguments<double2>); break;
}
return 0;
}
diff --git a/src/tuning/kernels/transpose_pad.hpp b/src/tuning/kernels/transpose_pad.hpp
index 8d24a0dc..ebc0e4fb 100644
--- a/src/tuning/kernels/transpose_pad.hpp
+++ b/src/tuning/kernels/transpose_pad.hpp
@@ -78,6 +78,15 @@ TunerSettings PadtransposeGetTunerSettings(const int, const Arguments<T> &args)
template <typename T>
void PadtransposeTestValidArguments(const int, const Arguments<T> &) { }
std::vector<Constraint> PadtransposeSetConstraints(const int) { return {}; }
+template <typename T>
+LocalMemSizeInfo PadtransposeComputeLocalMemSize(const int) {
+ return {
+ [] (std::vector<size_t> v) -> size_t {
+ return GetBytes(PrecisionValue<T>()) * (v[1] * v[0]) * (v[1] * v[0] + v[2]);
+ },
+ {"PADTRA_TILE", "PADTRA_WPT", "PADTRA_PAD"}
+ };
+}
// Sets the kernel's arguments
template <typename T>
diff --git a/src/tuning/kernels/xaxpy.cpp b/src/tuning/kernels/xaxpy.cpp
index 6a95600d..681876ea 100644
--- a/src/tuning/kernels/xaxpy.cpp
+++ b/src/tuning/kernels/xaxpy.cpp
@@ -22,11 +22,11 @@ using double2 = clblast::double2;
int main(int argc, char *argv[]) {
const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
switch(clblast::GetPrecision(command_line_args)) {
- case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, 0, clblast::XaxpyGetTunerDefaults, clblast::XaxpyGetTunerSettings<half>, clblast::XaxpyTestValidArguments<half>, clblast::XaxpySetConstraints, clblast::XaxpySetArguments<half>); break;
- case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, 0, clblast::XaxpyGetTunerDefaults, clblast::XaxpyGetTunerSettings<float>, clblast::XaxpyTestValidArguments<float>, clblast::XaxpySetConstraints, clblast::XaxpySetArguments<float>); break;
- case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, 0, clblast::XaxpyGetTunerDefaults, clblast::XaxpyGetTunerSettings<double>, clblast::XaxpyTestValidArguments<double>, clblast::XaxpySetConstraints, clblast::XaxpySetArguments<double>); break;
- case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, 0, clblast::XaxpyGetTunerDefaults, clblast::XaxpyGetTunerSettings<float2>, clblast::XaxpyTestValidArguments<float2>, clblast::XaxpySetConstraints, clblast::XaxpySetArguments<float2>); break;
- case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, 0, clblast::XaxpyGetTunerDefaults, clblast::XaxpyGetTunerSettings<double2>, clblast::XaxpyTestValidArguments<double2>, clblast::XaxpySetConstraints, clblast::XaxpySetArguments<double2>); break;
+ case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, 0, clblast::XaxpyGetTunerDefaults, clblast::XaxpyGetTunerSettings<half>, clblast::XaxpyTestValidArguments<half>, clblast::XaxpySetConstraints, clblast::XaxpyComputeLocalMemSize<half>, clblast::XaxpySetArguments<half>); break;
+ case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, 0, clblast::XaxpyGetTunerDefaults, clblast::XaxpyGetTunerSettings<float>, clblast::XaxpyTestValidArguments<float>, clblast::XaxpySetConstraints, clblast::XaxpyComputeLocalMemSize<float>, clblast::XaxpySetArguments<float>); break;
+ case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, 0, clblast::XaxpyGetTunerDefaults, clblast::XaxpyGetTunerSettings<double>, clblast::XaxpyTestValidArguments<double>, clblast::XaxpySetConstraints, clblast::XaxpyComputeLocalMemSize<double>, clblast::XaxpySetArguments<double>); break;
+ case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, 0, clblast::XaxpyGetTunerDefaults, clblast::XaxpyGetTunerSettings<float2>, clblast::XaxpyTestValidArguments<float2>, clblast::XaxpySetConstraints, clblast::XaxpyComputeLocalMemSize<float2>, clblast::XaxpySetArguments<float2>); break;
+ case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, 0, clblast::XaxpyGetTunerDefaults, clblast::XaxpyGetTunerSettings<double2>, clblast::XaxpyTestValidArguments<double2>, clblast::XaxpySetConstraints, clblast::XaxpyComputeLocalMemSize<double2>, clblast::XaxpySetArguments<double2>); break;
}
return 0;
}
diff --git a/src/tuning/kernels/xaxpy.hpp b/src/tuning/kernels/xaxpy.hpp
index 24550ed9..ab2c45f0 100644
--- a/src/tuning/kernels/xaxpy.hpp
+++ b/src/tuning/kernels/xaxpy.hpp
@@ -81,6 +81,10 @@ void XaxpyTestValidArguments(const int, const Arguments<T> &args) {
}
}
std::vector<Constraint> XaxpySetConstraints(const int) { return {}; }
+template <typename T>
+LocalMemSizeInfo XaxpyComputeLocalMemSize(const int) {
+ return { [] (std::vector<size_t>) -> size_t { return 0; }, {} };
+}
// Sets the kernel's arguments
template <typename T>
diff --git a/src/tuning/kernels/xdot.cpp b/src/tuning/kernels/xdot.cpp
index 6d10c4d8..a481f23b 100644
--- a/src/tuning/kernels/xdot.cpp
+++ b/src/tuning/kernels/xdot.cpp
@@ -24,11 +24,11 @@ template <int V>
void StartVariation(int argc, char *argv[]) {
const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
switch(clblast::GetPrecision(command_line_args)) {
- case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, V, clblast::XdotGetTunerDefaults, clblast::XdotGetTunerSettings<half>, clblast::XdotTestValidArguments<half>, clblast::XdotSetConstraints, clblast::XdotSetArguments<half>); break;
- case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, V, clblast::XdotGetTunerDefaults, clblast::XdotGetTunerSettings<float>, clblast::XdotTestValidArguments<float>, clblast::XdotSetConstraints, clblast::XdotSetArguments<float>); break;
- case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, V, clblast::XdotGetTunerDefaults, clblast::XdotGetTunerSettings<double>, clblast::XdotTestValidArguments<double>, clblast::XdotSetConstraints, clblast::XdotSetArguments<double>); break;
- case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, V, clblast::XdotGetTunerDefaults, clblast::XdotGetTunerSettings<float2>, clblast::XdotTestValidArguments<float2>, clblast::XdotSetConstraints, clblast::XdotSetArguments<float2>); break;
- case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, V, clblast::XdotGetTunerDefaults, clblast::XdotGetTunerSettings<double2>, clblast::XdotTestValidArguments<double2>, clblast::XdotSetConstraints, clblast::XdotSetArguments<double2>); break;
+ case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, V, clblast::XdotGetTunerDefaults, clblast::XdotGetTunerSettings<half>, clblast::XdotTestValidArguments<half>, clblast::XdotSetConstraints, clblast::XdotComputeLocalMemSize<half>, clblast::XdotSetArguments<half>); break;
+ case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, V, clblast::XdotGetTunerDefaults, clblast::XdotGetTunerSettings<float>, clblast::XdotTestValidArguments<float>, clblast::XdotSetConstraints, clblast::XdotComputeLocalMemSize<float>, clblast::XdotSetArguments<float>); break;
+ case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, V, clblast::XdotGetTunerDefaults, clblast::XdotGetTunerSettings<double>, clblast::XdotTestValidArguments<double>, clblast::XdotSetConstraints, clblast::XdotComputeLocalMemSize<double>, clblast::XdotSetArguments<double>); break;
+ case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, V, clblast::XdotGetTunerDefaults, clblast::XdotGetTunerSettings<float2>, clblast::XdotTestValidArguments<float2>, clblast::XdotSetConstraints, clblast::XdotComputeLocalMemSize<float2>, clblast::XdotSetArguments<float2>); break;
+ case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, V, clblast::XdotGetTunerDefaults, clblast::XdotGetTunerSettings<double2>, clblast::XdotTestValidArguments<double2>, clblast::XdotSetConstraints, clblast::XdotComputeLocalMemSize<double2>, clblast::XdotSetArguments<double2>); break;
}
}
diff --git a/src/tuning/kernels/xdot.hpp b/src/tuning/kernels/xdot.hpp
index 15673c79..901d8fd0 100644
--- a/src/tuning/kernels/xdot.hpp
+++ b/src/tuning/kernels/xdot.hpp
@@ -76,6 +76,14 @@ TunerSettings XdotGetTunerSettings(const int V, const Arguments<T> &args) {
template <typename T>
void XdotTestValidArguments(const int, const Arguments<T> &) { }
std::vector<Constraint> XdotSetConstraints(const int) { return {}; }
+template <typename T>
+LocalMemSizeInfo XdotComputeLocalMemSize(const int V) {
+ return {
+ [] (std::vector<size_t> v) -> size_t {
+ return GetBytes(PrecisionValue<T>()) * v[0];
+ },
+ {"WGS"+std::to_string(V)}
+ };}
// Sets the kernel's arguments
template <typename T>
diff --git a/src/tuning/kernels/xgemm.cpp b/src/tuning/kernels/xgemm.cpp
index d365ce6d..85948373 100644
--- a/src/tuning/kernels/xgemm.cpp
+++ b/src/tuning/kernels/xgemm.cpp
@@ -23,11 +23,11 @@ template <int V>
void StartVariation(int argc, char *argv[]) {
const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
switch(clblast::GetPrecision(command_line_args)) {
- case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, V, clblast::XgemmGetTunerDefaults, clblast::XgemmGetTunerSettings<half>, clblast::XgemmTestValidArguments<half>, clblast::XgemmSetConstraints, clblast::XgemmSetArguments<half>); break;
- case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, V, clblast::XgemmGetTunerDefaults, clblast::XgemmGetTunerSettings<float>, clblast::XgemmTestValidArguments<float>, clblast::XgemmSetConstraints, clblast::XgemmSetArguments<float>); break;
- case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, V, clblast::XgemmGetTunerDefaults, clblast::XgemmGetTunerSettings<double>, clblast::XgemmTestValidArguments<double>, clblast::XgemmSetConstraints, clblast::XgemmSetArguments<double>); break;
- case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, V, clblast::XgemmGetTunerDefaults, clblast::XgemmGetTunerSettings<float2>, clblast::XgemmTestValidArguments<float2>, clblast::XgemmSetConstraints, clblast::XgemmSetArguments<float2>); break;
- case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, V, clblast::XgemmGetTunerDefaults, clblast::XgemmGetTunerSettings<double2>, clblast::XgemmTestValidArguments<double2>, clblast::XgemmSetConstraints, clblast::XgemmSetArguments<double2>); break;
+ case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, V, clblast::XgemmGetTunerDefaults, clblast::XgemmGetTunerSettings<half>, clblast::XgemmTestValidArguments<half>, clblast::XgemmSetConstraints, clblast::XgemmComputeLocalMemSize<half>, clblast::XgemmSetArguments<half>); break;
+ case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, V, clblast::XgemmGetTunerDefaults, clblast::XgemmGetTunerSettings<float>, clblast::XgemmTestValidArguments<float>, clblast::XgemmSetConstraints, clblast::XgemmComputeLocalMemSize<float>, clblast::XgemmSetArguments<float>); break;
+ case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, V, clblast::XgemmGetTunerDefaults, clblast::XgemmGetTunerSettings<double>, clblast::XgemmTestValidArguments<double>, clblast::XgemmSetConstraints, clblast::XgemmComputeLocalMemSize<double>, clblast::XgemmSetArguments<double>); break;
+ case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, V, clblast::XgemmGetTunerDefaults, clblast::XgemmGetTunerSettings<float2>, clblast::XgemmTestValidArguments<float2>, clblast::XgemmSetConstraints, clblast::XgemmComputeLocalMemSize<float2>, clblast::XgemmSetArguments<float2>); break;
+ case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, V, clblast::XgemmGetTunerDefaults, clblast::XgemmGetTunerSettings<double2>, clblast::XgemmTestValidArguments<double2>, clblast::XgemmSetConstraints, clblast::XgemmComputeLocalMemSize<double2>, clblast::XgemmSetArguments<double2>); break;
}
}
diff --git a/src/tuning/kernels/xgemm.hpp b/src/tuning/kernels/xgemm.hpp
index 66e197e1..5f191ba9 100644
--- a/src/tuning/kernels/xgemm.hpp
+++ b/src/tuning/kernels/xgemm.hpp
@@ -145,6 +145,15 @@ std::vector<Constraint> XgemmSetConstraints(const int V) {
}
return constraints;
}
+template <typename T>
+LocalMemSizeInfo XgemmComputeLocalMemSize(const int) {
+ return {
+ [] (std::vector<size_t> v) -> size_t {
+ return GetBytes(PrecisionValue<T>()) * ((v[0]*v[1]*v[2]) + (v[3]*v[4]*v[5]));
+ },
+ {"SA", "KWG", "MWG", "SB", "KWG", "NWG"}
+ };
+}
// Sets the kernel's arguments
template <typename T>
diff --git a/src/tuning/kernels/xgemm_direct.cpp b/src/tuning/kernels/xgemm_direct.cpp
index 7298a6c3..73c2217c 100644
--- a/src/tuning/kernels/xgemm_direct.cpp
+++ b/src/tuning/kernels/xgemm_direct.cpp
@@ -23,11 +23,11 @@ template <int V>
void StartVariation(int argc, char *argv[]) {
const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
switch(clblast::GetPrecision(command_line_args)) {
- case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, V, clblast::XgemmDirectGetTunerDefaults, clblast::XgemmDirectGetTunerSettings<half>, clblast::XgemmDirectTestValidArguments<half>, clblast::XgemmDirectSetConstraints, clblast::XgemmDirectSetArguments<half>); break;
- case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, V, clblast::XgemmDirectGetTunerDefaults, clblast::XgemmDirectGetTunerSettings<float>, clblast::XgemmDirectTestValidArguments<float>, clblast::XgemmDirectSetConstraints, clblast::XgemmDirectSetArguments<float>); break;
- case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, V, clblast::XgemmDirectGetTunerDefaults, clblast::XgemmDirectGetTunerSettings<double>, clblast::XgemmDirectTestValidArguments<double>, clblast::XgemmDirectSetConstraints, clblast::XgemmDirectSetArguments<double>); break;
- case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, V, clblast::XgemmDirectGetTunerDefaults, clblast::XgemmDirectGetTunerSettings<float2>, clblast::XgemmDirectTestValidArguments<float2>, clblast::XgemmDirectSetConstraints, clblast::XgemmDirectSetArguments<float2>); break;
- case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, V, clblast::XgemmDirectGetTunerDefaults, clblast::XgemmDirectGetTunerSettings<double2>, clblast::XgemmDirectTestValidArguments<double2>, clblast::XgemmDirectSetConstraints, clblast::XgemmDirectSetArguments<double2>); break;
+ case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, V, clblast::XgemmDirectGetTunerDefaults, clblast::XgemmDirectGetTunerSettings<half>, clblast::XgemmDirectTestValidArguments<half>, clblast::XgemmDirectSetConstraints, clblast::XgemmDirectComputeLocalMemSize<half>, clblast::XgemmDirectSetArguments<half>); break;
+ case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, V, clblast::XgemmDirectGetTunerDefaults, clblast::XgemmDirectGetTunerSettings<float>, clblast::XgemmDirectTestValidArguments<float>, clblast::XgemmDirectSetConstraints, clblast::XgemmDirectComputeLocalMemSize<float>, clblast::XgemmDirectSetArguments<float>); break;
+ case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, V, clblast::XgemmDirectGetTunerDefaults, clblast::XgemmDirectGetTunerSettings<double>, clblast::XgemmDirectTestValidArguments<double>, clblast::XgemmDirectSetConstraints, clblast::XgemmDirectComputeLocalMemSize<double>, clblast::XgemmDirectSetArguments<double>); break;
+ case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, V, clblast::XgemmDirectGetTunerDefaults, clblast::XgemmDirectGetTunerSettings<float2>, clblast::XgemmDirectTestValidArguments<float2>, clblast::XgemmDirectSetConstraints, clblast::XgemmDirectComputeLocalMemSize<float2>, clblast::XgemmDirectSetArguments<float2>); break;
+ case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, V, clblast::XgemmDirectGetTunerDefaults, clblast::XgemmDirectGetTunerSettings<double2>, clblast::XgemmDirectTestValidArguments<double2>, clblast::XgemmDirectSetConstraints, clblast::XgemmDirectComputeLocalMemSize<double2>, clblast::XgemmDirectSetArguments<double2>); break;
}
}
diff --git a/src/tuning/kernels/xgemm_direct.hpp b/src/tuning/kernels/xgemm_direct.hpp
index ecb10bc6..baa063c0 100644
--- a/src/tuning/kernels/xgemm_direct.hpp
+++ b/src/tuning/kernels/xgemm_direct.hpp
@@ -135,6 +135,15 @@ std::vector<Constraint> XgemmDirectSetConstraints(const int V) {
}
return constraints;
}
+template <typename T>
+LocalMemSizeInfo XgemmDirectComputeLocalMemSize(const int) {
+ return {
+ [] (std::vector<size_t> v) -> size_t {
+ return GetBytes(PrecisionValue<T>()) * ((v[0]*(v[0] + v[1]) + v[0]*(v[0] + v[2])));
+ },
+ {"WGD", "PADA", "PADB"}
+ };
+}
// Sets the kernel's arguments
template <typename T>
diff --git a/src/tuning/kernels/xgemv.cpp b/src/tuning/kernels/xgemv.cpp
index 9e45d73f..6505a081 100644
--- a/src/tuning/kernels/xgemv.cpp
+++ b/src/tuning/kernels/xgemv.cpp
@@ -23,11 +23,11 @@ template <int V>
void StartVariation(int argc, char *argv[]) {
const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
switch(clblast::GetPrecision(command_line_args)) {
- case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, V, clblast::XgemvGetTunerDefaults, clblast::XgemvGetTunerSettings<half>, clblast::XgemvTestValidArguments<half>, clblast::XgemvSetConstraints, clblast::XgemvSetArguments<half>); break;
- case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, V, clblast::XgemvGetTunerDefaults, clblast::XgemvGetTunerSettings<float>, clblast::XgemvTestValidArguments<float>, clblast::XgemvSetConstraints, clblast::XgemvSetArguments<float>); break;
- case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, V, clblast::XgemvGetTunerDefaults, clblast::XgemvGetTunerSettings<double>, clblast::XgemvTestValidArguments<double>, clblast::XgemvSetConstraints, clblast::XgemvSetArguments<double>); break;
- case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, V, clblast::XgemvGetTunerDefaults, clblast::XgemvGetTunerSettings<float2>, clblast::XgemvTestValidArguments<float2>, clblast::XgemvSetConstraints, clblast::XgemvSetArguments<float2>); break;
- case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, V, clblast::XgemvGetTunerDefaults, clblast::XgemvGetTunerSettings<double2>, clblast::XgemvTestValidArguments<double2>, clblast::XgemvSetConstraints, clblast::XgemvSetArguments<double2>); break;
+ case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, V, clblast::XgemvGetTunerDefaults, clblast::XgemvGetTunerSettings<half>, clblast::XgemvTestValidArguments<half>, clblast::XgemvSetConstraints, clblast::XgemvComputeLocalMemSize<half>, clblast::XgemvSetArguments<half>); break;
+ case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, V, clblast::XgemvGetTunerDefaults, clblast::XgemvGetTunerSettings<float>, clblast::XgemvTestValidArguments<float>, clblast::XgemvSetConstraints, clblast::XgemvComputeLocalMemSize<float>, clblast::XgemvSetArguments<float>); break;
+ case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, V, clblast::XgemvGetTunerDefaults, clblast::XgemvGetTunerSettings<double>, clblast::XgemvTestValidArguments<double>, clblast::XgemvSetConstraints, clblast::XgemvComputeLocalMemSize<double>, clblast::XgemvSetArguments<double>); break;
+ case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, V, clblast::XgemvGetTunerDefaults, clblast::XgemvGetTunerSettings<float2>, clblast::XgemvTestValidArguments<float2>, clblast::XgemvSetConstraints, clblast::XgemvComputeLocalMemSize<float2>, clblast::XgemvSetArguments<float2>); break;
+ case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, V, clblast::XgemvGetTunerDefaults, clblast::XgemvGetTunerSettings<double2>, clblast::XgemvTestValidArguments<double2>, clblast::XgemvSetConstraints, clblast::XgemvComputeLocalMemSize<double2>, clblast::XgemvSetArguments<double2>); break;
}
}
diff --git a/src/tuning/kernels/xgemv.hpp b/src/tuning/kernels/xgemv.hpp
index e44efe32..c582816e 100644
--- a/src/tuning/kernels/xgemv.hpp
+++ b/src/tuning/kernels/xgemv.hpp
@@ -109,6 +109,23 @@ std::vector<Constraint> XgemvSetConstraints(const int V) {
}
return constraints;
}
+template <typename T>
+LocalMemSizeInfo XgemvComputeLocalMemSize(const int V) {
+ if (V == 1 || V == 2) {
+ return {
+ [V] (std::vector<size_t> v) -> size_t {
+ return GetBytes(PrecisionValue<T>()) * v[0];
+ },
+ {"WGS" + std::to_string(V)}
+ };
+ }
+ return {
+ [V] (std::vector<size_t> v) -> size_t {
+ return GetBytes(PrecisionValue<T>()) * (v[0] + v[1] * v[2]);
+ },
+ {"WGS3", "WPT3", "WGS3"}
+ };
+}
// Sets the kernel's arguments
template <typename T>
diff --git a/src/tuning/kernels/xger.cpp b/src/tuning/kernels/xger.cpp
index 6dfc9ffa..e4c9fc03 100644
--- a/src/tuning/kernels/xger.cpp
+++ b/src/tuning/kernels/xger.cpp
@@ -22,11 +22,11 @@ using double2 = clblast::double2;
int main(int argc, char *argv[]) {
const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
switch(clblast::GetPrecision(command_line_args)) {
- case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, 0, clblast::XgerGetTunerDefaults, clblast::XgerGetTunerSettings<half>, clblast::XgerTestValidArguments<half>, clblast::XgerSetConstraints, clblast::XgerSetArguments<half>); break;
- case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, 0, clblast::XgerGetTunerDefaults, clblast::XgerGetTunerSettings<float>, clblast::XgerTestValidArguments<float>, clblast::XgerSetConstraints, clblast::XgerSetArguments<float>); break;
- case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, 0, clblast::XgerGetTunerDefaults, clblast::XgerGetTunerSettings<double>, clblast::XgerTestValidArguments<double>, clblast::XgerSetConstraints, clblast::XgerSetArguments<double>); break;
- case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, 0, clblast::XgerGetTunerDefaults, clblast::XgerGetTunerSettings<float2>, clblast::XgerTestValidArguments<float2>, clblast::XgerSetConstraints, clblast::XgerSetArguments<float2>); break;
- case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, 0, clblast::XgerGetTunerDefaults, clblast::XgerGetTunerSettings<double2>, clblast::XgerTestValidArguments<double2>, clblast::XgerSetConstraints, clblast::XgerSetArguments<double2>); break;
+ case clblast::Precision::kHalf: clblast::Tuner<half>(argc, argv, 0, clblast::XgerGetTunerDefaults, clblast::XgerGetTunerSettings<half>, clblast::XgerTestValidArguments<half>, clblast::XgerSetConstraints, clblast::XgerComputeLocalMemSize<half>, clblast::XgerSetArguments<half>); break;
+ case clblast::Precision::kSingle: clblast::Tuner<float>(argc, argv, 0, clblast::XgerGetTunerDefaults, clblast::XgerGetTunerSettings<float>, clblast::XgerTestValidArguments<float>, clblast::XgerSetConstraints, clblast::XgerComputeLocalMemSize<float>, clblast::XgerSetArguments<float>); break;
+ case clblast::Precision::kDouble: clblast::Tuner<double>(argc, argv, 0, clblast::XgerGetTunerDefaults, clblast::XgerGetTunerSettings<double>, clblast::XgerTestValidArguments<double>, clblast::XgerSetConstraints, clblast::XgerComputeLocalMemSize<double>, clblast::XgerSetArguments<double>); break;
+ case clblast::Precision::kComplexSingle: clblast::Tuner<float2>(argc, argv, 0, clblast::XgerGetTunerDefaults, clblast::XgerGetTunerSettings<float2>, clblast::XgerTestValidArguments<float2>, clblast::XgerSetConstraints, clblast::XgerComputeLocalMemSize<float2>, clblast::XgerSetArguments<float2>); break;
+ case clblast::Precision::kComplexDouble: clblast::Tuner<double2>(argc, argv, 0, clblast::XgerGetTunerDefaults, clblast::XgerGetTunerSettings<double2>, clblast::XgerTestValidArguments<double2>, clblast::XgerSetConstraints, clblast::XgerComputeLocalMemSize<double2>, clblast::XgerSetArguments<double2>); break;
}
return 0;
}
diff --git a/src/tuning/kernels/xger.hpp b/src/tuning/kernels/xger.hpp
index afd2f36e..0473572d 100644
--- a/src/tuning/kernels/xger.hpp
+++ b/src/tuning/kernels/xger.hpp
@@ -79,6 +79,10 @@ TunerSettings XgerGetTunerSettings(const int, const Arguments<T> &args) {
template <typename T>
void XgerTestValidArguments(const int, const Arguments<T> &) { }
std::vector<Constraint> XgerSetConstraints(const int) { return {}; }
+template <typename T>
+LocalMemSizeInfo XgerComputeLocalMemSize(const int) {
+ return { [] (std::vector<size_t>) -> size_t { return 0; }, {} };
+}
// Sets the kernel's arguments
template <typename T>
diff --git a/src/tuning/tuning.cpp b/src/tuning/tuning.cpp
index b5e01f65..dd4a83e6 100644
--- a/src/tuning/tuning.cpp
+++ b/src/tuning/tuning.cpp
@@ -93,6 +93,7 @@ void Tuner(int argc, char* argv[], const int V,
GetTunerSettingsFunc<T> GetTunerSettings,
TestValidArgumentsFunc<T> TestValidArguments,
SetConstraintsFunc SetConstraints,
+ ComputeLocalMemSizeFunc<T> ComputeLocalMemSize,
SetArgumentsFunc<T> SetArguments) {
constexpr auto kSeed = 42; // fixed seed for reproducibility
@@ -171,7 +172,8 @@ void Tuner(int argc, char* argv[], const int V,
}
// Sets the tunable parameters and their possible values
- auto configurations = SetConfigurations(settings.parameters, SetConstraints(V));
+ auto configurations = SetConfigurations(device, settings.parameters,
+ SetConstraints(V), ComputeLocalMemSize(V));
printf("* Found %s%zu configuration(s)%s\n",
kPrintMessage.c_str(), configurations.size(), kPrintEnd.c_str());
@@ -380,11 +382,11 @@ void Tuner(int argc, char* argv[], const int V,
}
// Compiles the above function
-template void Tuner<half>(int argc, char* argv[], const int V, GetTunerDefaultsFunc GetTunerDefaults, GetTunerSettingsFunc<half> GetTunerSettings, TestValidArgumentsFunc<half> TestValidArguments, SetConstraintsFunc SetConstraints, SetArgumentsFunc<half> SetArguments);
-template void Tuner<float>(int argc, char* argv[], const int V, GetTunerDefaultsFunc GetTunerDefaults, GetTunerSettingsFunc<float> GetTunerSettings, TestValidArgumentsFunc<float> TestValidArguments, SetConstraintsFunc SetConstraints, SetArgumentsFunc<float> SetArguments);
-template void Tuner<double>(int argc, char* argv[], const int V, GetTunerDefaultsFunc GetTunerDefaults, GetTunerSettingsFunc<double> GetTunerSettings, TestValidArgumentsFunc<double> TestValidArguments, SetConstraintsFunc SetConstraints, SetArgumentsFunc<double> SetArguments);
-template void Tuner<float2>(int argc, char* argv[], const int V, GetTunerDefaultsFunc GetTunerDefaults, GetTunerSettingsFunc<float2> GetTunerSettings, TestValidArgumentsFunc<float2> TestValidArguments, SetConstraintsFunc SetConstraints, SetArgumentsFunc<float2> SetArguments);
-template void Tuner<double2>(int argc, char* argv[], const int V, GetTunerDefaultsFunc GetTunerDefaults, GetTunerSettingsFunc<double2> GetTunerSettings, TestValidArgumentsFunc<double2> TestValidArguments, SetConstraintsFunc SetConstraints, SetArgumentsFunc<double2> SetArguments);
+template void Tuner<half>(int argc, char* argv[], const int V, GetTunerDefaultsFunc GetTunerDefaults, GetTunerSettingsFunc<half> GetTunerSettings, TestValidArgumentsFunc<half> TestValidArguments, SetConstraintsFunc SetConstraints, ComputeLocalMemSizeFunc<half> ComputeLocalMemSize, SetArgumentsFunc<half> SetArguments);
+template void Tuner<float>(int argc, char* argv[], const int V, GetTunerDefaultsFunc GetTunerDefaults, GetTunerSettingsFunc<float> GetTunerSettings, TestValidArgumentsFunc<float> TestValidArguments, SetConstraintsFunc SetConstraints, ComputeLocalMemSizeFunc<float> ComputeLocalMemSize, SetArgumentsFunc<float> SetArguments);
+template void Tuner<double>(int argc, char* argv[], const int V, GetTunerDefaultsFunc GetTunerDefaults, GetTunerSettingsFunc<double> GetTunerSettings, TestValidArgumentsFunc<double> TestValidArguments, SetConstraintsFunc SetConstraints, ComputeLocalMemSizeFunc<double> ComputeLocalMemSize, SetArgumentsFunc<double> SetArguments);
+template void Tuner<float2>(int argc, char* argv[], const int V, GetTunerDefaultsFunc GetTunerDefaults, GetTunerSettingsFunc<float2> GetTunerSettings, TestValidArgumentsFunc<float2> TestValidArguments, SetConstraintsFunc SetConstraints, ComputeLocalMemSizeFunc<float2> ComputeLocalMemSize, SetArgumentsFunc<float2> SetArguments);
+template void Tuner<double2>(int argc, char* argv[], const int V, GetTunerDefaultsFunc GetTunerDefaults, GetTunerSettingsFunc<double2> GetTunerSettings, TestValidArgumentsFunc<double2> TestValidArguments, SetConstraintsFunc SetConstraints, ComputeLocalMemSizeFunc<double2> ComputeLocalMemSize, SetArgumentsFunc<double2> SetArguments);
// =================================================================================================
} // namespace clblast
diff --git a/src/tuning/tuning.hpp b/src/tuning/tuning.hpp
index cbecc300..37a042ff 100644
--- a/src/tuning/tuning.hpp
+++ b/src/tuning/tuning.hpp
@@ -108,6 +108,8 @@ template <typename T>
using TestValidArgumentsFunc = std::function<void(const int V, const Arguments<T> &args)>;
using SetConstraintsFunc = std::function<std::vector<Constraint>(const int V)>;
template <typename T>
+using ComputeLocalMemSizeFunc = std::function<LocalMemSizeInfo(const int V)>;
+template <typename T>
using SetArgumentsFunc = std::function<void(const int V, Kernel &kernel, const Arguments<T> &args, std::vector<Buffer<T>>& buffers)>;
// Function to get command-line argument, set-up the input buffers, configure the tuner, and collect
@@ -119,6 +121,7 @@ void Tuner(int argc, char* argv[], const int V,
GetTunerSettingsFunc<T> GetTunerSettings,
TestValidArgumentsFunc<T> TestValidArguments,
SetConstraintsFunc SetConstraints,
+ ComputeLocalMemSizeFunc<T> ComputeLocalMemSize,
SetArgumentsFunc<T> SetArguments);
// Function to run the tuners through the CLBlast API, no I/O
@@ -128,6 +131,7 @@ StatusCode TunerAPI(Queue &queue, const Arguments<T> &args, const int V,
const GetTunerSettingsFunc<T> GetTunerSettings,
const TestValidArgumentsFunc<T> TestValidArguments,
const SetConstraintsFunc SetConstraints,
+ const ComputeLocalMemSizeFunc<T> ComputeLocalMemSize,
const SetArgumentsFunc<T> SetArguments,
std::unordered_map<std::string,size_t> &parameters);
diff --git a/src/tuning/tuning_api.cpp b/src/tuning/tuning_api.cpp
index d03b428c..f37b3600 100644
--- a/src/tuning/tuning_api.cpp
+++ b/src/tuning/tuning_api.cpp
@@ -40,7 +40,7 @@ StatusCode TuneXaxpy(RawCommandQueue * queue, const size_t n,
auto args = Arguments<T>(); args.fraction = fraction; args.n = n;
auto queue_cpp = Queue(*queue);
return TunerAPI<T>(queue_cpp, args, 0, XaxpyGetTunerDefaults, XaxpyGetTunerSettings<T>,
- XaxpyTestValidArguments<T>, XaxpySetConstraints, XaxpySetArguments<T>, parameters);
+ XaxpyTestValidArguments<T>, XaxpySetConstraints, XaxpyComputeLocalMemSize<T>, XaxpySetArguments<T>, parameters);
}
template StatusCode PUBLIC_API TuneXaxpy<half>(RawCommandQueue*, const size_t, const double, std::unordered_map<std::string,size_t>&);
template StatusCode PUBLIC_API TuneXaxpy<float>(RawCommandQueue*, const size_t, const double, std::unordered_map<std::string,size_t>&);
@@ -54,10 +54,10 @@ StatusCode TuneXdot(RawCommandQueue * queue, const size_t n,
auto args = Arguments<T>(); args.fraction = fraction; args.n = n;
auto queue_cpp = Queue(*queue);
auto status = TunerAPI<T>(queue_cpp, args, 1, XdotGetTunerDefaults, XdotGetTunerSettings<T>,
- XdotTestValidArguments<T>, XdotSetConstraints, XdotSetArguments<T>, parameters);
+ XdotTestValidArguments<T>, XdotSetConstraints, XdotComputeLocalMemSize<T>, XdotSetArguments<T>, parameters);
if (status != StatusCode::kSuccess) { return status; }
return TunerAPI<T>(queue_cpp, args, 2, XdotGetTunerDefaults, XdotGetTunerSettings<T>,
- XdotTestValidArguments<T>, XdotSetConstraints, XdotSetArguments<T>, parameters);
+ XdotTestValidArguments<T>, XdotSetConstraints, XdotComputeLocalMemSize<T>, XdotSetArguments<T>, parameters);
}
template StatusCode PUBLIC_API TuneXdot<half>(RawCommandQueue*, const size_t, const double, std::unordered_map<std::string,size_t>&);
template StatusCode PUBLIC_API TuneXdot<float>(RawCommandQueue*, const size_t, const double, std::unordered_map<std::string,size_t>&);
@@ -71,13 +71,13 @@ StatusCode TuneXgemv(RawCommandQueue * queue, const size_t m, const size_t n,
auto args = Arguments<T>(); args.fraction = fraction; args.m = m; args.n = n;
auto queue_cpp = Queue(*queue);
auto status = TunerAPI<T>(queue_cpp, args, 1, XgemvGetTunerDefaults, XgemvGetTunerSettings<T>,
- XgemvTestValidArguments<T>, XgemvSetConstraints, XgemvSetArguments<T>, parameters);
+ XgemvTestValidArguments<T>, XgemvSetConstraints, XgemvComputeLocalMemSize<T>, XgemvSetArguments<T>, parameters);
if (status != StatusCode::kSuccess) { return status; }
status = TunerAPI<T>(queue_cpp, args, 2, XgemvGetTunerDefaults, XgemvGetTunerSettings<T>,
- XgemvTestValidArguments<T>, XgemvSetConstraints, XgemvSetArguments<T>, parameters);
+ XgemvTestValidArguments<T>, XgemvSetConstraints, XgemvComputeLocalMemSize<T>, XgemvSetArguments<T>, parameters);
if (status != StatusCode::kSuccess) { return status; }
return TunerAPI<T>(queue_cpp, args, 3, XgemvGetTunerDefaults, XgemvGetTunerSettings<T>,
- XgemvTestValidArguments<T>, XgemvSetConstraints, XgemvSetArguments<T>, parameters);
+ XgemvTestValidArguments<T>, XgemvSetConstraints, XgemvComputeLocalMemSize<T>, XgemvSetArguments<T>, parameters);
}
template StatusCode PUBLIC_API TuneXgemv<half>(RawCommandQueue*, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
template StatusCode PUBLIC_API TuneXgemv<float>(RawCommandQueue*, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
@@ -91,7 +91,7 @@ StatusCode TuneXger(RawCommandQueue * queue, const size_t m, const size_t n,
auto args = Arguments<T>(); args.fraction = fraction; args.m = m; args.n = n;
auto queue_cpp = Queue(*queue);
return TunerAPI<T>(queue_cpp, args, 0, XgerGetTunerDefaults, XgerGetTunerSettings<T>,
- XgerTestValidArguments<T>, XgerSetConstraints, XgerSetArguments<T>, parameters);
+ XgerTestValidArguments<T>, XgerSetConstraints, XgerComputeLocalMemSize<T>, XgerSetArguments<T>, parameters);
}
template StatusCode PUBLIC_API TuneXger<half>(RawCommandQueue*, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
template StatusCode PUBLIC_API TuneXger<float>(RawCommandQueue*, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
@@ -105,7 +105,7 @@ StatusCode TuneXgemm(RawCommandQueue * queue, const size_t m, const size_t n, co
auto args = Arguments<T>(); args.fraction = fraction; args.m = m; args.n = n; args.k = k;
auto queue_cpp = Queue(*queue);
return TunerAPI<T>(queue_cpp, args, 2, XgemmGetTunerDefaults, XgemmGetTunerSettings<T>,
- XgemmTestValidArguments<T>, XgemmSetConstraints, XgemmSetArguments<T>, parameters);
+ XgemmTestValidArguments<T>, XgemmSetConstraints, XgemmComputeLocalMemSize<T>, XgemmSetArguments<T>, parameters);
}
template StatusCode PUBLIC_API TuneXgemm<half>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
template StatusCode PUBLIC_API TuneXgemm<float>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
@@ -119,7 +119,7 @@ StatusCode TuneXgemmDirect(RawCommandQueue * queue, const size_t m, const size_t
auto args = Arguments<T>(); args.fraction = fraction; args.m = m; args.n = n; args.k = k;
auto queue_cpp = Queue(*queue);
return TunerAPI<T>(queue_cpp, args, 2, XgemmDirectGetTunerDefaults, XgemmDirectGetTunerSettings<T>,
- XgemmDirectTestValidArguments<T>, XgemmDirectSetConstraints, XgemmDirectSetArguments<T>, parameters);
+ XgemmDirectTestValidArguments<T>, XgemmDirectSetConstraints, XgemmDirectComputeLocalMemSize<T>, XgemmDirectSetArguments<T>, parameters);
}
template StatusCode PUBLIC_API TuneXgemmDirect<half>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
template StatusCode PUBLIC_API TuneXgemmDirect<float>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
@@ -133,7 +133,7 @@ StatusCode TuneCopy(RawCommandQueue * queue, const size_t m, const size_t n,
auto args = Arguments<T>(); args.fraction = fraction; args.m = m; args.n = n;
auto queue_cpp = Queue(*queue);
return TunerAPI<T>(queue_cpp, args, 0, CopyGetTunerDefaults, CopyGetTunerSettings<T>,
- CopyTestValidArguments<T>, CopySetConstraints, CopySetArguments<T>, parameters);
+ CopyTestValidArguments<T>, CopySetConstraints, CopyComputeLocalMemSize<T>, CopySetArguments<T>, parameters);
}
template StatusCode PUBLIC_API TuneCopy<half>(RawCommandQueue*, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
template StatusCode PUBLIC_API TuneCopy<float>(RawCommandQueue*, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
@@ -147,7 +147,7 @@ StatusCode TunePad(RawCommandQueue * queue, const size_t m, const size_t n,
auto args = Arguments<T>(); args.fraction = fraction; args.m = m; args.n = n;
auto queue_cpp = Queue(*queue);
return TunerAPI<T>(queue_cpp, args, 0, PadGetTunerDefaults, PadGetTunerSettings<T>,
- PadTestValidArguments<T>, PadSetConstraints, PadSetArguments<T>, parameters);
+ PadTestValidArguments<T>, PadSetConstraints, PadComputeLocalMemSize<T>, PadSetArguments<T>, parameters);
}
template StatusCode PUBLIC_API TunePad<half>(RawCommandQueue*, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
template StatusCode PUBLIC_API TunePad<float>(RawCommandQueue*, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
@@ -161,7 +161,7 @@ StatusCode TuneTranspose(RawCommandQueue * queue, const size_t m, const size_t n
auto args = Arguments<T>(); args.fraction = fraction; args.m = m; args.n = n;
auto queue_cpp = Queue(*queue);
return TunerAPI<T>(queue_cpp, args, 0, TransposeGetTunerDefaults, TransposeGetTunerSettings<T>,
- TransposeTestValidArguments<T>, TransposeSetConstraints, TransposeSetArguments<T>, parameters);
+ TransposeTestValidArguments<T>, TransposeSetConstraints, TransposeComputeLocalMemSize<T>, TransposeSetArguments<T>, parameters);
}
template StatusCode PUBLIC_API TuneTranspose<half>(RawCommandQueue*, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
template StatusCode PUBLIC_API TuneTranspose<float>(RawCommandQueue*, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
@@ -175,7 +175,7 @@ StatusCode TunePadtranspose(RawCommandQueue * queue, const size_t m, const size_
auto args = Arguments<T>(); args.fraction = fraction; args.m = m; args.n = n;
auto queue_cpp = Queue(*queue);
return TunerAPI<T>(queue_cpp, args, 0, PadtransposeGetTunerDefaults, PadtransposeGetTunerSettings<T>,
- PadtransposeTestValidArguments<T>, PadtransposeSetConstraints, PadtransposeSetArguments<T>, parameters);
+ PadtransposeTestValidArguments<T>, PadtransposeSetConstraints, PadtransposeComputeLocalMemSize<T>, PadtransposeSetArguments<T>, parameters);
}
template StatusCode PUBLIC_API TunePadtranspose<half>(RawCommandQueue*, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
template StatusCode PUBLIC_API TunePadtranspose<float>(RawCommandQueue*, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
@@ -189,7 +189,7 @@ StatusCode TuneInvert(RawCommandQueue * queue, const size_t m, const size_t n, c
auto args = Arguments<T>(); args.fraction = fraction; args.m = m; args.n = n; args.k = k;
auto queue_cpp = Queue(*queue);
return TunerAPI<T>(queue_cpp, args, 0, InvertGetTunerDefaults, InvertGetTunerSettings<T>,
- InvertTestValidArguments<T>, InvertSetConstraints, InvertSetArguments<T>, parameters);
+ InvertTestValidArguments<T>, InvertSetConstraints, InvertComputeLocalMemSize<T>, InvertSetArguments<T>, parameters);
}
template StatusCode PUBLIC_API TuneInvert<half>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
template StatusCode PUBLIC_API TuneInvert<float>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
@@ -206,6 +206,7 @@ StatusCode TunerAPI(Queue &queue, const Arguments<T> &args, const int V,
const GetTunerSettingsFunc<T> GetTunerSettings,
const TestValidArgumentsFunc<T> TestValidArguments,
const SetConstraintsFunc SetConstraints,
+ const ComputeLocalMemSizeFunc<T> ComputeLocalMemSize,
const SetArgumentsFunc<T> SetArguments,
std::unordered_map<std::string,size_t> &parameters) {
@@ -260,7 +261,8 @@ StatusCode TunerAPI(Queue &queue, const Arguments<T> &args, const int V,
}
// Sets the tunable parameters and their possible values
- auto configurations = SetConfigurations(settings.parameters, SetConstraints(V));
+ auto configurations = SetConfigurations(device, settings.parameters,
+ SetConstraints(V), ComputeLocalMemSize(V));
// Select the search method (full search or a random fraction)
if (args.fraction != 0.0 && args.fraction != 1.0) {
@@ -375,11 +377,11 @@ StatusCode TunerAPI(Queue &queue, const Arguments<T> &args, const int V,
}
// Compiles the above function
-template StatusCode TunerAPI<half>(Queue &queue, const Arguments<half> &args, const int V, const GetTunerDefaultsFunc GetTunerDefaults, const GetTunerSettingsFunc<half> GetTunerSettings, const TestValidArgumentsFunc<half> TestValidArguments, const SetConstraintsFunc SetConstraints, const SetArgumentsFunc<half> SetArguments, std::unordered_map<std::string,size_t>&);
-template StatusCode TunerAPI<float>(Queue &queue, const Arguments<float> &args, const int V, const GetTunerDefaultsFunc GetTunerDefaults, const GetTunerSettingsFunc<float> GetTunerSettings, const TestValidArgumentsFunc<float> TestValidArguments, const SetConstraintsFunc SetConstraints, const SetArgumentsFunc<float> SetArguments, std::unordered_map<std::string,size_t>&);
-template StatusCode TunerAPI<double>(Queue &queue, const Arguments<double> &args, const int V, const GetTunerDefaultsFunc GetTunerDefaults, const GetTunerSettingsFunc<double> GetTunerSettings, const TestValidArgumentsFunc<double> TestValidArguments, const SetConstraintsFunc SetConstraints, const SetArgumentsFunc<double> SetArguments, std::unordered_map<std::string,size_t>&);
-template StatusCode TunerAPI<float2>(Queue &queue, const Arguments<float2> &args, const int V, const GetTunerDefaultsFunc GetTunerDefaults, const GetTunerSettingsFunc<float2> GetTunerSettings, const TestValidArgumentsFunc<float2> TestValidArguments, const SetConstraintsFunc SetConstraints, const SetArgumentsFunc<float2> SetArguments, std::unordered_map<std::string,size_t>&);
-template StatusCode TunerAPI<double2>(Queue &queue, const Arguments<double2> &args, const int V, const GetTunerDefaultsFunc GetTunerDefaults, const GetTunerSettingsFunc<double2> GetTunerSettings, const TestValidArgumentsFunc<double2> TestValidArguments, const SetConstraintsFunc SetConstraints, const SetArgumentsFunc<double2> SetArguments, std::unordered_map<std::string,size_t>&);
+template StatusCode TunerAPI<half>(Queue &queue, const Arguments<half> &args, const int V, const GetTunerDefaultsFunc GetTunerDefaults, const GetTunerSettingsFunc<half> GetTunerSettings, const TestValidArgumentsFunc<half> TestValidArguments, const SetConstraintsFunc SetConstraints, const ComputeLocalMemSizeFunc<half> ComputeLocalMemSize, const SetArgumentsFunc<half> SetArguments, std::unordered_map<std::string,size_t>&);
+template StatusCode TunerAPI<float>(Queue &queue, const Arguments<float> &args, const int V, const GetTunerDefaultsFunc GetTunerDefaults, const GetTunerSettingsFunc<float> GetTunerSettings, const TestValidArgumentsFunc<float> TestValidArguments, const SetConstraintsFunc SetConstraints, const ComputeLocalMemSizeFunc<float> ComputeLocalMemSize, const SetArgumentsFunc<float> SetArguments, std::unordered_map<std::string,size_t>&);
+template StatusCode TunerAPI<double>(Queue &queue, const Arguments<double> &args, const int V, const GetTunerDefaultsFunc GetTunerDefaults, const GetTunerSettingsFunc<double> GetTunerSettings, const TestValidArgumentsFunc<double> TestValidArguments, const SetConstraintsFunc SetConstraints, const ComputeLocalMemSizeFunc<double> ComputeLocalMemSize, const SetArgumentsFunc<double> SetArguments, std::unordered_map<std::string,size_t>&);
+template StatusCode TunerAPI<float2>(Queue &queue, const Arguments<float2> &args, const int V, const GetTunerDefaultsFunc GetTunerDefaults, const GetTunerSettingsFunc<float2> GetTunerSettings, const TestValidArgumentsFunc<float2> TestValidArguments, const SetConstraintsFunc SetConstraints, const ComputeLocalMemSizeFunc<float2> ComputeLocalMemSize, const SetArgumentsFunc<float2> SetArguments, std::unordered_map<std::string,size_t>&);
+template StatusCode TunerAPI<double2>(Queue &queue, const Arguments<double2> &args, const int V, const GetTunerDefaultsFunc GetTunerDefaults, const GetTunerSettingsFunc<double2> GetTunerSettings, const TestValidArgumentsFunc<double2> TestValidArguments, const SetConstraintsFunc SetConstraints, const ComputeLocalMemSizeFunc<double2> ComputeLocalMemSize, const SetArgumentsFunc<double2> SetArguments, std::unordered_map<std::string,size_t>&);
// =================================================================================================
} // namespace clblast