summaryrefslogtreecommitdiff
path: root/src/database/database.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/database.cpp')
-rw-r--r--src/database/database.cpp53
1 files changed, 34 insertions, 19 deletions
diff --git a/src/database/database.cpp b/src/database/database.cpp
index 34c44a29..2340a89c 100644
--- a/src/database/database.cpp
+++ b/src/database/database.cpp
@@ -21,27 +21,42 @@
#include "database/kernels/xgemv_fast_rot.hpp"
#include "database/kernels/xger.hpp"
#include "database/kernels/xgemm.hpp"
+#include "database/kernels/xgemm_direct.hpp"
#include "database/kernels/copy.hpp"
#include "database/kernels/pad.hpp"
#include "database/kernels/transpose.hpp"
#include "database/kernels/padtranspose.hpp"
+#include "database/kernel_selection.hpp"
namespace clblast {
// =================================================================================================
// Initializes the database
-const std::vector<Database::DatabaseEntry> Database::database = {
- XaxpyHalf, XaxpySingle, XaxpyDouble, XaxpyComplexSingle, XaxpyComplexDouble,
- XdotHalf, XdotSingle, XdotDouble, XdotComplexSingle, XdotComplexDouble,
- XgemvHalf, XgemvSingle, XgemvDouble, XgemvComplexSingle, XgemvComplexDouble,
- XgemvFastHalf, XgemvFastSingle, XgemvFastDouble, XgemvFastComplexSingle, XgemvFastComplexDouble,
- XgemvFastRotHalf, XgemvFastRotSingle, XgemvFastRotDouble, XgemvFastRotComplexSingle, XgemvFastRotComplexDouble,
- XgerHalf, XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble,
- XgemmHalf, XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble,
- CopyHalf, CopySingle, CopyDouble, CopyComplexSingle, CopyComplexDouble,
- PadHalf, PadSingle, PadDouble, PadComplexSingle, PadComplexDouble,
- TransposeHalf, TransposeSingle, TransposeDouble, TransposeComplexSingle, TransposeComplexDouble,
- PadtransposeHalf, PadtransposeSingle, PadtransposeDouble, PadtransposeComplexSingle, PadtransposeComplexDouble
+const std::vector<const Database::DatabaseEntry*> Database::database = {
+ &database::XaxpyHalf, &database::XaxpySingle, &database::XaxpyDouble, &database::XaxpyComplexSingle, &database::XaxpyComplexDouble,
+ &database::XdotHalf, &database::XdotSingle, &database::XdotDouble, &database::XdotComplexSingle, &database::XdotComplexDouble,
+ &database::XgemvHalf, &database::XgemvSingle, &database::XgemvDouble, &database::XgemvComplexSingle, &database::XgemvComplexDouble,
+ &database::XgemvFastHalf, &database::XgemvFastSingle, &database::XgemvFastDouble, &database::XgemvFastComplexSingle, &database::XgemvFastComplexDouble,
+ &database::XgemvFastRotHalf, &database::XgemvFastRotSingle, &database::XgemvFastRotDouble, &database::XgemvFastRotComplexSingle, &database::XgemvFastRotComplexDouble,
+ &database::XgerHalf, &database::XgerSingle, &database::XgerDouble, &database::XgerComplexSingle, &database::XgerComplexDouble,
+ &database::XgemmHalf, &database::XgemmSingle, &database::XgemmDouble, &database::XgemmComplexSingle, &database::XgemmComplexDouble,
+ &database::XgemmDirectHalf, &database::XgemmDirectSingle, &database::XgemmDirectDouble, &database::XgemmDirectComplexSingle, &database::XgemmDirectComplexDouble,
+ &database::CopyHalf, &database::CopySingle, &database::CopyDouble, &database::CopyComplexSingle, &database::CopyComplexDouble,
+ &database::PadHalf, &database::PadSingle, &database::PadDouble, &database::PadComplexSingle, &database::PadComplexDouble,
+ &database::TransposeHalf, &database::TransposeSingle, &database::TransposeDouble, &database::TransposeComplexSingle, &database::TransposeComplexDouble,
+ &database::PadtransposeHalf, &database::PadtransposeSingle, &database::PadtransposeDouble, &database::PadtransposeComplexSingle, &database::PadtransposeComplexDouble,
+ &database::KernelSelectionHalf, &database::KernelSelectionSingle, &database::KernelSelectionDouble, &database::KernelSelectionComplexSingle, &database::KernelSelectionComplexDouble
+};
+
+// The OpenCL device vendors
+const std::string Database::kDeviceVendorAll = "default";
+
+// Alternative names for some OpenCL vendors
+const std::unordered_map<std::string, std::string> Database::kVendorNames{
+ { "Intel(R) Corporation", "Intel" },
+ { "GenuineIntel", "Intel" },
+ { "Advanced Micro Devices, Inc.", "AMD" },
+ { "NVIDIA Corporation", "NVIDIA" },
};
// =================================================================================================
@@ -49,7 +64,7 @@ const std::vector<Database::DatabaseEntry> Database::database = {
// Constructor, computing device properties and populating the parameter-vector from the database.
// This takes an optional overlay database in case of custom tuning or custom kernels.
Database::Database(const Queue &queue, const std::vector<std::string> &kernels,
- const Precision precision, const std::vector<DatabaseEntry> &overlay):
+ const Precision precision, const std::vector<const DatabaseEntry*> &overlay):
parameters_{} {
// Finds information of the current device
@@ -69,8 +84,8 @@ Database::Database(const Queue &queue, const std::vector<std::string> &kernels,
for (auto &kernel: kernels) {
auto search_result = ParametersPtr{};
- for (auto db: { &overlay, &database }) {
- search_result = Search(kernel, device_type, device_vendor, device_name, precision, *db);
+ for (auto &db: { database, overlay}) {
+ search_result = Search(kernel, device_type, device_vendor, device_name, precision, db);
if (search_result) {
parameters_.insert(search_result->begin(), search_result->end());
break;
@@ -100,17 +115,17 @@ Database::ParametersPtr Database::Search(const std::string &this_kernel,
const std::string &this_vendor,
const std::string &this_device,
const Precision this_precision,
- const std::vector<DatabaseEntry> &this_database) const {
+ const std::vector<const DatabaseEntry*> &this_database) const {
// Selects the right kernel
for (auto &db: this_database) {
- if (db.kernel == this_kernel && db.precision == this_precision) {
+ if (db->kernel == this_kernel && db->precision == this_precision) {
// Searches for the right vendor and device type, or selects the default if unavailable. This
// assumes that the default vendor / device type is last in the database.
- for (auto &vendor: db.vendors) {
+ for (auto &vendor: db->vendors) {
if ((vendor.name == this_vendor || vendor.name == kDeviceVendorAll) &&
- (vendor.type == this_type || vendor.type == kDeviceTypeAll)) {
+ (vendor.type == this_type || vendor.type == database::kDeviceTypeAll)) {
// Searches for the right device. If the current device is unavailable, selects the vendor
// default parameters. This assumes the default is last in the database.