summaryrefslogtreecommitdiff
path: root/src/tuning/tuning.hpp
diff options
context:
space:
mode:
authormcian <mcian86@gmail.com>2017-07-23 14:48:13 +0200
committermcian <mcian86@gmail.com>2017-07-23 14:48:13 +0200
commit473e81471895b35dcec5cb82e6beba134c544006 (patch)
treee4ff6df062b45644bc0ca8c0fb7640864128ad2f /src/tuning/tuning.hpp
parent8131e68664e02c8a1bc5a0f5598294fd3bc5b974 (diff)
Code refactoring
Diffstat (limited to 'src/tuning/tuning.hpp')
-rw-r--r--src/tuning/tuning.hpp60
1 files changed, 15 insertions, 45 deletions
diff --git a/src/tuning/tuning.hpp b/src/tuning/tuning.hpp
index 35b320cb..2e0eb5a1 100644
--- a/src/tuning/tuning.hpp
+++ b/src/tuning/tuning.hpp
@@ -48,11 +48,13 @@ void Tuner(int argc, char* argv[]) {
if (o == kArgBeta) { args.beta = GetArgument(command_line_args, help, kArgBeta, GetScalar<T>()); }
if (o == kArgFraction) { args.fraction = GetArgument(command_line_args, help, kArgFraction, C::DefaultFraction()); }
if (o == kArgBatchCount) { args.batch_count = GetArgument(command_line_args, help, kArgBatchCount, C::DefaultBatchCount()); }
- if (o == tStrategy) {args.tStrategy = GetArgument(command_line_args, help, tStrategy, DEFAULT_STRATEGY); }
- if (o == psoSwarmSize) {args.psoSwarmSize = GetArgument(command_line_args, help, psoSwarmSize, DEFAULT_PSO_SWARM); }
- if (o == psoInfG) {args.psoInfG = GetArgument(command_line_args, help, psoInfG, DEFAULT_PSO_G); }
- if (o == psoInfL) {args.psoInfL = GetArgument(command_line_args, help, psoInfL, DEFAULT_PSO_L); }
- if (o == psoInfR) {args.psoInfR = GetArgument(command_line_args, help, psoInfR, DEFAULT_PSO_R); }
+ if (o == kArgHeuristicSelection) {args.heuristic_selection = GetArgument(command_line_args, help, kArgHeuristicSelection, C::DefaultHeuristic()); }
+ if (o == kArgMultiSearchStrategy) {args.multi_search_strategy = GetArgument(command_line_args, help, kArgMultiSearchStrategy, 0);}
+ if (o == kArgPsoSwarmSize) {args.pso_swarm_size = GetArgument(command_line_args, help, kArgPsoSwarmSize , C::DefaultSwarmSizePSO()); }
+ if (o == kArgPsoInfGlobal) {args.pso_inf_global = GetArgument(command_line_args, help, kArgPsoInfGlobal, C::DefaultInfluenceGlobalPSO()); }
+ if (o == kArgPsoInfLocal) {args.pso_inf_local = GetArgument(command_line_args, help, kArgPsoInfLocal, C::DefaultInfluenceLocalPSO()); }
+ if (o == kArgPsoInfRandom) {args.pso_inf_random = GetArgument(command_line_args, help, kArgPsoInfRandom, C::DefaultInfluenceRandomPSO()); }
+ if (o == kArgAnnMaxTemp) {args.ann_max_temperature = GetArgument(command_line_args, help, kArgAnnMaxTemp, C::DefaultMaxTempAnn());}
}
const auto num_runs = GetArgument(command_line_args, help, kArgNumRuns, C::DefaultNumRuns());
@@ -96,48 +98,17 @@ void Tuner(int argc, char* argv[]) {
// Initializes the tuner for the chosen device
cltune::Tuner tuner(args.platform_id, args.device_id);
- // Use full-search to explore all parameter combinations or random-search to search only a part of
- // the parameter values. The fraction is set as a command-line argument.
- #ifdef XGEMM_EXEC
+ // Select the search method based on the cmd_line arguments
+ // If the tuner does not support the selected choice, Full Search will be returned.
+ auto method = C::GetCurrentHeuristic(args);
- if(tStrategyFlag)
- {
- auto localtStrategy = args.tStrategy;
-
- if (args.fraction == 1.0 || args.fraction == 0.0)
- {
- localtStrategy = FULL_SEARCH_STRATEGY;
- }
- switch (localtStrategy)
- {
- case FULL_SEARCH_STRATEGY:
- tuner.UseFullSearch();
- break;
-
- case RANDOM_SEARCH_STRATEGY:
- tuner.UseRandomSearch(1.0/args.fraction);
- break;
- case PSO_STRATEGY:
- tuner.UsePSO(1.0/args.fraction, args.psoSwarmSize, args.psoInfG, args.psoInfL, args.psoInfR);
- break;
- case DVDT_STRATEGY:
- default:
- tuner.UseFullSearch();
- }
+ if (method == 1) { tuner.UseRandomSearch(1.0/args.fraction); }
+ else if (method == 2) { tuner.UseAnnealing(args.fraction, args.ann_max_temperature); }
+ else if (method == 3) {
+ tuner.UsePSO(args.fraction, args.pso_swarm_size, args.pso_inf_global, args.pso_inf_local, args.pso_inf_random);
}
+ else { tuner.UseFullSearch(); }
- #else
-
- if (args.fraction == 1.0 || args.fraction == 0.0)
- {
- tuner.UseFullSearch();
- }
- else
- {
- tuner.UseRandomSearch(1.0/args.fraction);
- }
-
- #endif
// Set extra settings for specific defines. This mimics src/routine.cc.
auto defines = std::string{""};
if (isAMD && isGPU) {
@@ -201,7 +172,6 @@ void Tuner(int argc, char* argv[]) {
}
tuner.PrintJSON("clblast_"+C::KernelFamily()+"_"+precision_string+".json", metadata);
-
}
// =================================================================================================