summaryrefslogtreecommitdiff
path: root/src/tuning
diff options
context:
space:
mode:
Diffstat (limited to 'src/tuning')
-rw-r--r--src/tuning/tuning_api.cpp42
1 files changed, 42 insertions, 0 deletions
diff --git a/src/tuning/tuning_api.cpp b/src/tuning/tuning_api.cpp
index 61cb0389..4ffb46c2 100644
--- a/src/tuning/tuning_api.cpp
+++ b/src/tuning/tuning_api.cpp
@@ -91,6 +91,34 @@ template StatusCode TuneXger<float2>(RawCommandQueue*, const size_t, const size_
template StatusCode TuneXger<double2>(RawCommandQueue*, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
template <typename T>
+StatusCode TuneXgemm(RawCommandQueue * queue, const size_t m, const size_t n, const size_t k,
+ const double fraction, std::unordered_map<std::string,size_t> &parameters) {
+ auto args = Arguments<T>(); args.fraction = fraction; args.m = m; args.n = n; args.k = k;
+ auto queue_cpp = Queue(*queue);
+ return TunerAPI<T>(queue_cpp, args, 2, GetTunerDefaults, GetTunerSettings<T>,
+ TestValidArguments<T>, SetConstraints, SetArguments<T>, parameters);
+}
+template StatusCode TuneXgemm<half>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
+template StatusCode TuneXgemm<float>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
+template StatusCode TuneXgemm<double>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
+template StatusCode TuneXgemm<float2>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
+template StatusCode TuneXgemm<double2>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
+
+template <typename T>
+StatusCode TuneXgemmDirect(RawCommandQueue * queue, const size_t m, const size_t n, const size_t k,
+ const double fraction, std::unordered_map<std::string,size_t> &parameters) {
+ auto args = Arguments<T>(); args.fraction = fraction; args.m = m; args.n = n; args.k = k;
+ auto queue_cpp = Queue(*queue);
+ return TunerAPI<T>(queue_cpp, args, 2, GetTunerDefaults, GetTunerSettings<T>,
+ TestValidArguments<T>, SetConstraints, SetArguments<T>, parameters);
+}
+template StatusCode TuneXgemmDirect<half>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
+template StatusCode TuneXgemmDirect<float>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
+template StatusCode TuneXgemmDirect<double>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
+template StatusCode TuneXgemmDirect<float2>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
+template StatusCode TuneXgemmDirect<double2>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
+
+template <typename T>
StatusCode TuneCopy(RawCommandQueue * queue, const size_t m, const size_t n,
const double fraction, std::unordered_map<std::string,size_t> &parameters) {
auto args = Arguments<T>(); args.fraction = fraction; args.m = m; args.n = n;
@@ -146,6 +174,20 @@ template StatusCode TunePadtranspose<double>(RawCommandQueue*, const size_t, con
template StatusCode TunePadtranspose<float2>(RawCommandQueue*, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
template StatusCode TunePadtranspose<double2>(RawCommandQueue*, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
+template <typename T>
+StatusCode TuneInvert(RawCommandQueue * queue, const size_t m, const size_t n, const size_t k,
+ const double fraction, std::unordered_map<std::string,size_t> &parameters) {
+ auto args = Arguments<T>(); args.fraction = fraction; args.m = m; args.n = n; args.k = k;
+ auto queue_cpp = Queue(*queue);
+ return TunerAPI<T>(queue_cpp, args, 0, GetTunerDefaults, GetTunerSettings<T>,
+ TestValidArguments<T>, SetConstraints, SetArguments<T>, parameters);
+}
+template StatusCode TuneInvert<half>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
+template StatusCode TuneInvert<float>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
+template StatusCode TuneInvert<double>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
+template StatusCode TuneInvert<float2>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
+template StatusCode TuneInvert<double2>(RawCommandQueue*, const size_t, const size_t, const size_t, const double, std::unordered_map<std::string,size_t>&);
+
// =================================================================================================
// The main tuner API, similar to the one in tuning.cpp, but without I/O