diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-08-21 20:06:29 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-08-21 20:06:29 +0200 |
commit | e5eb6b1d3a66358093cb40f0fad51ecdc4654771 (patch) | |
tree | e5eb03736240ec07534319fdd15661e1093f04ac /src/tuning/tuning.hpp | |
parent | d67fd6604b4a6584c4f9e856057fcc8076ce377d (diff) | |
parent | dfd332524ab0e66a04d803bb075787e35cd2db1a (diff) |
Merge pull request #173 from mcian/PSO_params
Add PSO parameters support and search strategy selection from command…
Diffstat (limited to 'src/tuning/tuning.hpp')
-rw-r--r-- | src/tuning/tuning.hpp | 23 |
1 files changed, 16 insertions, 7 deletions
diff --git a/src/tuning/tuning.hpp b/src/tuning/tuning.hpp index 25504430..6a8039d2 100644 --- a/src/tuning/tuning.hpp +++ b/src/tuning/tuning.hpp @@ -48,6 +48,12 @@ void Tuner(int argc, char* argv[]) { if (o == kArgBeta) { args.beta = GetArgument(command_line_args, help, kArgBeta, GetScalar<T>()); } if (o == kArgFraction) { args.fraction = GetArgument(command_line_args, help, kArgFraction, C::DefaultFraction()); } if (o == kArgBatchCount) { args.batch_count = GetArgument(command_line_args, help, kArgBatchCount, C::DefaultBatchCount()); } + if (o == kArgHeuristicSelection) {args.heuristic_selection = GetArgument(command_line_args, help, kArgHeuristicSelection, C::DefaultHeuristic()); } + if (o == kArgPsoSwarmSize) {args.pso_swarm_size = GetArgument(command_line_args, help, kArgPsoSwarmSize , C::DefaultSwarmSizePSO()); } + if (o == kArgPsoInfGlobal) {args.pso_inf_global = GetArgument(command_line_args, help, kArgPsoInfGlobal, C::DefaultInfluenceGlobalPSO()); } + if (o == kArgPsoInfLocal) {args.pso_inf_local = GetArgument(command_line_args, help, kArgPsoInfLocal, C::DefaultInfluenceLocalPSO()); } + if (o == kArgPsoInfRandom) {args.pso_inf_random = GetArgument(command_line_args, help, kArgPsoInfRandom, C::DefaultInfluenceRandomPSO()); } + if (o == kArgAnnMaxTemp) {args.ann_max_temperature = GetArgument(command_line_args, help, kArgAnnMaxTemp, C::DefaultMaxTempAnn());} } const auto num_runs = GetArgument(command_line_args, help, kArgNumRuns, C::DefaultNumRuns()); @@ -91,14 +97,16 @@ void Tuner(int argc, char* argv[]) { // Initializes the tuner for the chosen device cltune::Tuner tuner(args.platform_id, args.device_id); - // Use full-search to explore all parameter combinations or random-search to search only a part of - // the parameter values. The fraction is set as a command-line argument. - if (args.fraction == 1.0 || args.fraction == 0.0) { - tuner.UseFullSearch(); - } - else { - tuner.UseRandomSearch(1.0/args.fraction); + // Select the search method based on the cmd_line arguments + // If the tuner does not support the selected choice, Full Search will be returned. + auto method = C::GetHeuristic(args); + + if (method == 1) { tuner.UseRandomSearch(1.0/args.fraction); } + else if (method == 2) { tuner.UseAnnealing(args.fraction, args.ann_max_temperature); } + else if (method == 3) { + tuner.UsePSO(args.fraction, args.pso_swarm_size, args.pso_inf_global, args.pso_inf_local, args.pso_inf_random); } + else { tuner.UseFullSearch(); } // Set extra settings for specific defines. This mimics src/routine.cc. auto defines = std::string{""}; @@ -162,6 +170,7 @@ void Tuner(int argc, char* argv[]) { if (o == kArgBatchCount) { metadata.push_back({"arg_batch_count", ToString(args.batch_count)}); } } tuner.PrintJSON("clblast_"+C::KernelFamily()+"_"+precision_string+".json", metadata); + } // ================================================================================================= |