summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- include/clblast.h 8
-rwxr-xr-x scripts/generator/generator.py 4
-rw-r--r-- src/clblast.cpp 32
-rw-r--r-- src/routine.cpp 25
-rw-r--r-- src/routine.hpp 10
5 files changed, 77 insertions, 2 deletions
diff --git a/include/clblast.h b/include/clblast.h
index 7b2021d8..e7b53d65 100644
--- a/include/clblast.h
+++ b/include/clblast.h
@@ -17,6 +17,8 @@
#define CLBLAST_CLBLAST_H_
#include <cstdlib> // For size_t
+#include <string> // For OverrideParameters function
+#include <unordered_map> // For OverrideParameters function
// Includes the normal OpenCL C header
#if defined(__APPLE__) || defined(__MACOSX)
@@ -617,6 +619,12 @@ StatusCode PUBLIC_API FillCache(const cl_device_id device);
// =================================================================================================
+StatusCode PUBLIC_API OverrideParameters(const cl_device_id device, const std::string &kernel_name,
+ const Precision precision, // with device and kernel_name: identifies the database entry to replace
+ const std::unordered_map<std::string,size_t> &parameters); // contents of the replacement entry
+
+// =================================================================================================
+
} // namespace clblast
// CLBLAST_CLBLAST_H_
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 6591cbf7..aaf1b121 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -41,8 +41,8 @@ FILES = [
"/include/clblast_netlib_c.h",
"/src/clblast_netlib_c.cpp",
]
-HEADER_LINES = [117, 73, 118, 22, 29, 41, 65, 32]
-FOOTER_LINES = [17, 95, 19, 18, 6, 6, 9, 2]
+HEADER_LINES = [119, 73, 118, 22, 29, 41, 65, 32]
+FOOTER_LINES = [23, 128, 19, 18, 6, 6, 9, 2]
# Different possibilities for requirements
ald_m = "The value of `a_ld` must be at least `m`."
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 35f3f552..885b849e 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -2254,4 +2254,36 @@ StatusCode FillCache(const cl_device_id device) {
}
// =================================================================================================
+
+// Overrides a kernel's database parameters for this device & precision; drops stale caches
+StatusCode OverrideParameters(const cl_device_id device, const std::string &kernel_name,
+                              const Precision precision,
+                              const std::unordered_map<std::string,size_t> &parameters) {
+  try {
+    // Retrieves the device name, which is part of the binary- and database-cache keys below
+    const auto device_cpp = Device(device);
+    const auto device_name = device_cpp.Name();
+
+    // Clears the existing program & binary cache for routines with the target kernel. The list is
+    // bound by reference (no copy); .at() throws for an unknown kernel name, handled by the catch
+    const auto &routine_names = Routine::routines_by_kernel.at(kernel_name);
+    for (const auto &routine_name : routine_names) {
+      ProgramCache::Instance().RemoveBySubset<1, 2>(ProgramKey{nullptr, precision, routine_name});
+      BinaryCache::Instance().Remove(BinaryKey{precision, routine_name, device_name});
+    }
+
+    // Creates a small custom database based on the provided parameters
+    const auto database_device = Database::DatabaseDevice{"default", parameters};
+    const auto database_vendor = Database::DatabaseVendor{database::kDeviceTypeAll, "default", {database_device}};
+    const auto database_entry = Database::DatabaseEntry{kernel_name, precision, {database_vendor}};
+    const auto database = Database(device_cpp, kernel_name, precision, {&database_entry});
+
+    // Removes the old database entry and stores the new one in the cache
+    DatabaseCache::Instance().Remove(DatabaseKey{ precision, device_name, kernel_name });
+    DatabaseCache::Instance().Store(DatabaseKey{ precision, device_name, kernel_name }, Database(database));
+  } catch (...) { return DispatchException(); }
+  return StatusCode::kSuccess;
+}
+
+// =================================================================================================
} // namespace clblast
diff --git a/src/routine.cpp b/src/routine.cpp
index 854c7046..3cd045c8 100644
--- a/src/routine.cpp
+++ b/src/routine.cpp
@@ -21,6 +21,31 @@
namespace clblast {
// =================================================================================================
+// For each kernel this map contains a list of routines it is used in
+const std::vector<std::string> Routine::routines_axpy = {"AXPY", "COPY", "SCAL", "SWAP"};
+const std::vector<std::string> Routine::routines_dot = {"AMAX", "ASUM", "DOT", "DOTC", "DOTU", "MAX", "MIN", "NRM2", "SUM"};
+const std::vector<std::string> Routine::routines_ger = {"GER", "GERC", "GERU", "HER", "HER2", "HPR", "HPR2", "SPR", "SPR2", "SYR", "SYR2"};
+const std::vector<std::string> Routine::routines_gemv = {"GBMV", "GEMV", "HBMV", "HEMV", "HPMV", "SBMV", "SPMV", "SYMV", "TBMV", "TPMV", "TRMV"};
+const std::vector<std::string> Routine::routines_gemm = {"GEMM", "HEMM", "SYMM", "TRMM"};
+const std::vector<std::string> Routine::routines_gemm_syrk = {"GEMM", "HEMM", "HER2K", "HERK", "SYMM", "SYR2K", "SYRK", "TRMM"};
+// Keys must be unique: with a duplicate key it is unspecified which of the entries the map keeps
+const std::unordered_map<std::string, const std::vector<std::string>> Routine::routines_by_kernel = {
+  {"Xaxpy", routines_axpy},
+  {"Xdot", routines_dot},
+  {"Xgemv", routines_gemv},
+  {"XgemvFast", routines_gemv},
+  {"XgemvFastRot", routines_gemv},
+  {"Xger", routines_ger},
+  {"Copy", routines_gemm_syrk},
+  {"Pad", routines_gemm_syrk},
+  {"Transpose", routines_gemm_syrk},
+  {"Padtranspose", routines_gemm_syrk},
+  {"Xgemm", routines_gemm_syrk},
+  {"XgemmDirect", routines_gemm},
+  {"KernelSelection", routines_gemm},
+};
+// =================================================================================================
+
// The constructor does all heavy work, errors are returned as exceptions
Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
const std::vector<std::string> &kernel_names, const Precision precision,
diff --git a/src/routine.hpp b/src/routine.hpp
index ba8b9f60..622a1c0d 100644
--- a/src/routine.hpp
+++ b/src/routine.hpp
@@ -18,6 +18,7 @@
#include <string>
#include <vector>
+#include <unordered_map>
#include "utilities/utilities.hpp"
#include "cache.hpp"
@@ -42,6 +43,15 @@ class Routine {
const std::vector<const Database::DatabaseEntry*> &userDatabase,
std::initializer_list<const char *> source);
+ // Kernel-to-routines look-ups (defined in routine.cpp); public because OverrideParameters reads them
+ static const std::vector<std::string> routines_axpy;
+ static const std::vector<std::string> routines_dot;
+ static const std::vector<std::string> routines_ger;
+ static const std::vector<std::string> routines_gemv;
+ static const std::vector<std::string> routines_gemm;
+ static const std::vector<std::string> routines_gemm_syrk;
+ static const std::unordered_map<std::string, const std::vector<std::string>> routines_by_kernel;
+
private:
// Initializes program_, fetching cached program or building one