summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2016-09-06 20:30:06 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2016-09-06 20:30:06 +0200
commit55038d3c919a6584e5e5891e2290c67698f3c90d (patch)
treef9a9ae2723ecd1fe4ad242b27923449e532ec4df
parenta2f83507033a20b534099c7b21d4a7466108e949 (diff)
Split GEMM tuning in two parts: a small set of tuning parameters which is explored exhaustively and a larger set which is explored randomly
-rw-r--r--src/tuning/kernels/xgemm.cpp85
1 files changed, 60 insertions, 25 deletions
diff --git a/src/tuning/kernels/xgemm.cpp b/src/tuning/kernels/xgemm.cpp
index eb7c8a66..7c9ac76a 100644
--- a/src/tuning/kernels/xgemm.cpp
+++ b/src/tuning/kernels/xgemm.cpp
@@ -7,7 +7,9 @@
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
-// This file uses the CLTune auto-tuner to tune the xgemm OpenCL kernels.
+// This file uses the CLTune auto-tuner to tune the xgemm OpenCL kernels. There are two variations:
+// - V==1: This tests some limited set of tuning parameters exhaustively.
+// - V==2: This tests a much larger set of tuning parameters by randomly sampling a subset.
//
// =================================================================================================
@@ -21,12 +23,12 @@ namespace clblast {
// =================================================================================================
// See comment at top of file for a description of the class
-template <typename T>
+template <typename T, int V>
class TuneXgemm {
public:
// The representative kernel and the source code
- static std::string KernelFamily() { return "xgemm"; }
+ static std::string KernelFamily() { return (V==1) ? "xgemm_1" : "xgemm_2"; }
static std::string KernelName() { return "Xgemm"; }
static std::string GetSources() {
return
@@ -48,7 +50,7 @@ class TuneXgemm {
static size_t DefaultM() { return 1024; }
static size_t DefaultN() { return 1024; }
static size_t DefaultK() { return 1024; }
- static double DefaultFraction() { return 256.0; }
+ static double DefaultFraction() { return (V==1) ? 1.0 : 512.0; } // test all or sample randomly
// Describes how to obtain the sizes of the buffers
static size_t GetSizeX(const Arguments<T> &) { return 1; } // N/A for this kernel
@@ -60,20 +62,38 @@ class TuneXgemm {
// Sets the tuning parameters and their possible values
static void SetParameters(cltune::Tuner &tuner, const size_t id) {
- tuner.AddParameter(id, "MWG", {16, 32, 64, 128});
- tuner.AddParameter(id, "NWG", {16, 32, 64, 128});
- tuner.AddParameter(id, "KWG", {16, 32});
- tuner.AddParameter(id, "MDIMC", {8, 16, 32});
- tuner.AddParameter(id, "NDIMC", {8, 16, 32});
- tuner.AddParameter(id, "MDIMA", {8, 16, 32});
- tuner.AddParameter(id, "NDIMB", {8, 16, 32});
- tuner.AddParameter(id, "KWI", {2});
- tuner.AddParameter(id, "VWM", {1, 2, 4});
- tuner.AddParameter(id, "VWN", {1, 2, 4});
- tuner.AddParameter(id, "STRM", {0, 1});
- tuner.AddParameter(id, "STRN", {0, 1});
- tuner.AddParameter(id, "SA", {0, 1});
- tuner.AddParameter(id, "SB", {0, 1});
+ if (V==1) { // limited subset of tuning parameters - but explorable exhaustively
+ tuner.AddParameter(id, "MWG", {16, 32, 64});
+ tuner.AddParameter(id, "NWG", {16, 32, 64});
+ tuner.AddParameter(id, "KWG", {32});
+ tuner.AddParameter(id, "MDIMC", {8, 16, 32});
+ tuner.AddParameter(id, "NDIMC", {8, 16, 32});
+ tuner.AddParameter(id, "MDIMA", {8, 16, 32});
+ tuner.AddParameter(id, "NDIMB", {8, 16, 32});
+ tuner.AddParameter(id, "KWI", {2});
+ tuner.AddParameter(id, "VWM", {1, 2, 4});
+ tuner.AddParameter(id, "VWN", {1, 2, 4});
+ tuner.AddParameter(id, "STRM", {0});
+ tuner.AddParameter(id, "STRN", {0});
+ tuner.AddParameter(id, "SA", {0, 1});
+ tuner.AddParameter(id, "SB", {0, 1});
+ } // a lot more tuning parameters - has to be sampled randomly, too much to test all
+ else {
+ tuner.AddParameter(id, "MWG", {16, 32, 64, 128});
+ tuner.AddParameter(id, "NWG", {16, 32, 64, 128});
+ tuner.AddParameter(id, "KWG", {16, 32});
+ tuner.AddParameter(id, "MDIMC", {8, 16, 32});
+ tuner.AddParameter(id, "NDIMC", {8, 16, 32});
+ tuner.AddParameter(id, "MDIMA", {8, 16, 32});
+ tuner.AddParameter(id, "NDIMB", {8, 16, 32});
+ tuner.AddParameter(id, "KWI", {2});
+ tuner.AddParameter(id, "VWM", {1, 2, 4, 8});
+ tuner.AddParameter(id, "VWN", {1, 2, 4, 8});
+ tuner.AddParameter(id, "STRM", {0, 1});
+ tuner.AddParameter(id, "STRN", {0, 1});
+ tuner.AddParameter(id, "SA", {0, 1});
+ tuner.AddParameter(id, "SB", {0, 1});
+ }
}
// Sets the constraints
@@ -92,6 +112,14 @@ class TuneXgemm {
// KWG has to be a multiple of KDIMA = ((MDIMC*NDIMC)/(MDIMA)) and KDIMB = (...)
tuner.AddConstraint(id, MultipleOfXMulYDivZ, {"KWG", "MDIMC", "NDIMC", "MDIMA"});
tuner.AddConstraint(id, MultipleOfXMulYDivZ, {"KWG", "MDIMC", "NDIMC", "NDIMB"});
+
+ // Extra constraints for variation 1 to limit the set of options significantly
+ if (V==1) {
+ auto IsEqual = [] (std::vector<size_t> v) { return v[0] == v[1]; };
+ tuner.AddConstraint(id, IsEqual, {"MDIMC", "MDIMA"});
+ tuner.AddConstraint(id, IsEqual, {"NDIMC", "NDIMB"});
+ tuner.AddConstraint(id, IsEqual, {"SA", "SB"});
+ }
}
// Sets the local memory size
@@ -145,15 +173,22 @@ class TuneXgemm {
using float2 = clblast::float2;
using double2 = clblast::double2;
-// Main function (not within the clblast namespace)
-int main(int argc, char *argv[]) {
+// Function to tune a specific variation V (not within the clblast namespace)
+template <int V>
+void StartVariation(int argc, char *argv[]) {
switch(clblast::GetPrecision(argc, argv)) {
- case clblast::Precision::kHalf: clblast::Tuner<clblast::TuneXgemm<half>, half>(argc, argv); break;
- case clblast::Precision::kSingle: clblast::Tuner<clblast::TuneXgemm<float>, float>(argc, argv); break;
- case clblast::Precision::kDouble: clblast::Tuner<clblast::TuneXgemm<double>, double>(argc, argv); break;
- case clblast::Precision::kComplexSingle: clblast::Tuner<clblast::TuneXgemm<float2>, float2>(argc, argv); break;
- case clblast::Precision::kComplexDouble: clblast::Tuner<clblast::TuneXgemm<double2>, double2>(argc, argv); break;
+ case clblast::Precision::kHalf: clblast::Tuner<clblast::TuneXgemm<half,V>, half>(argc, argv); break;
+ case clblast::Precision::kSingle: clblast::Tuner<clblast::TuneXgemm<float,V>, float>(argc, argv); break;
+ case clblast::Precision::kDouble: clblast::Tuner<clblast::TuneXgemm<double,V>, double>(argc, argv); break;
+ case clblast::Precision::kComplexSingle: clblast::Tuner<clblast::TuneXgemm<float2,V>, float2>(argc, argv); break;
+ case clblast::Precision::kComplexDouble: clblast::Tuner<clblast::TuneXgemm<double2,V>, double2>(argc, argv); break;
}
+}
+
+// Main function (not within the clblast namespace)
+int main(int argc, char *argv[]) {
+ StartVariation<1>(argc, argv);
+ StartVariation<2>(argc, argv);
return 0;
}