diff options
author | mcian <mcian86@gmail.com> | 2017-07-17 12:00:25 +0200 |
---|---|---|
committer | mcian <mcian86@gmail.com> | 2017-07-17 12:00:25 +0200 |
commit | 8131e68664e02c8a1bc5a0f5598294fd3bc5b974 (patch) | |
tree | 62b40c06ed312a98e2f64a71f66053df630a5918 /src/tuning/tuning.hpp | |
parent | f2477f663672fd37301d6e2ce4646519f71d5cce (diff) |
Add PSO parameters support and search strategy selection from command line
Diffstat (limited to 'src/tuning/tuning.hpp')
-rw-r--r-- | src/tuning/tuning.hpp | 44 |
1 files changed, 42 insertions, 2 deletions
diff --git a/src/tuning/tuning.hpp b/src/tuning/tuning.hpp index 25504430..35b320cb 100644 --- a/src/tuning/tuning.hpp +++ b/src/tuning/tuning.hpp @@ -48,6 +48,11 @@ void Tuner(int argc, char* argv[]) { if (o == kArgBeta) { args.beta = GetArgument(command_line_args, help, kArgBeta, GetScalar<T>()); } if (o == kArgFraction) { args.fraction = GetArgument(command_line_args, help, kArgFraction, C::DefaultFraction()); } if (o == kArgBatchCount) { args.batch_count = GetArgument(command_line_args, help, kArgBatchCount, C::DefaultBatchCount()); } + if (o == tStrategy) {args.tStrategy = GetArgument(command_line_args, help, tStrategy, DEFAULT_STRATEGY); } + if (o == psoSwarmSize) {args.psoSwarmSize = GetArgument(command_line_args, help, psoSwarmSize, DEFAULT_PSO_SWARM); } + if (o == psoInfG) {args.psoInfG = GetArgument(command_line_args, help, psoInfG, DEFAULT_PSO_G); } + if (o == psoInfL) {args.psoInfL = GetArgument(command_line_args, help, psoInfL, DEFAULT_PSO_L); } + if (o == psoInfR) {args.psoInfR = GetArgument(command_line_args, help, psoInfR, DEFAULT_PSO_R); } } const auto num_runs = GetArgument(command_line_args, help, kArgNumRuns, C::DefaultNumRuns()); @@ -93,13 +98,46 @@ void Tuner(int argc, char* argv[]) { // Use full-search to explore all parameter combinations or random-search to search only a part of // the parameter values. The fraction is set as a command-line argument. - if (args.fraction == 1.0 || args.fraction == 0.0) { + #ifdef XGEMM_EXEC + + if(tStrategyFlag) + { + auto localtStrategy = args.tStrategy; + + if (args.fraction == 1.0 || args.fraction == 0.0) + { + localtStrategy = FULL_SEARCH_STRATEGY; + } + switch (localtStrategy) + { + case FULL_SEARCH_STRATEGY: + tuner.UseFullSearch(); + break; + + case RANDOM_SEARCH_STRATEGY: + tuner.UseRandomSearch(1.0/args.fraction); + break; + case PSO_STRATEGY: + tuner.UsePSO(1.0/args.fraction, args.psoSwarmSize, args.psoInfG, args.psoInfL, args.psoInfR); + break; + case DVDT_STRATEGY: + default: + tuner.UseFullSearch(); + } + } + + #else + + if (args.fraction == 1.0 || args.fraction == 0.0) + { tuner.UseFullSearch(); } - else { + else + { tuner.UseRandomSearch(1.0/args.fraction); } + #endif // Set extra settings for specific defines. This mimics src/routine.cc. auto defines = std::string{""}; if (isAMD && isGPU) { @@ -162,6 +200,8 @@ void Tuner(int argc, char* argv[]) { if (o == kArgBatchCount) { metadata.push_back({"arg_batch_count", ToString(args.batch_count)}); } } tuner.PrintJSON("clblast_"+C::KernelFamily()+"_"+precision_string+".json", metadata); + + } // ================================================================================================= |