diff options
Diffstat (limited to 'src/database/database.cpp')
-rw-r--r-- | src/database/database.cpp | 61 |
1 files changed, 38 insertions, 23 deletions
diff --git a/src/database/database.cpp b/src/database/database.cpp index 6ec93731..38974b95 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -17,6 +17,8 @@ #include "database/kernels/xaxpy.hpp" #include "database/kernels/xdot.hpp" #include "database/kernels/xgemv.hpp" +#include "database/kernels/xgemv_fast.hpp" +#include "database/kernels/xgemv_fast_rot.hpp" #include "database/kernels/xger.hpp" #include "database/kernels/xgemm.hpp" #include "database/kernels/copy.hpp" @@ -32,8 +34,10 @@ const std::vector<Database::DatabaseEntry> Database::database = { XaxpyHalf, XaxpySingle, XaxpyDouble, XaxpyComplexSingle, XaxpyComplexDouble, XdotHalf, XdotSingle, XdotDouble, XdotComplexSingle, XdotComplexDouble, XgemvHalf, XgemvSingle, XgemvDouble, XgemvComplexSingle, XgemvComplexDouble, + XgemvFastHalf, XgemvFastSingle, XgemvFastDouble, XgemvFastComplexSingle, XgemvFastComplexDouble, + /* XgemvFastRotHalf, */ XgemvFastRotSingle, XgemvFastRotDouble, XgemvFastRotComplexSingle, XgemvFastRotComplexDouble, XgerHalf, XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble, - XgemmHalf, XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble, + /* XgemmHalf, */ XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble, CopyHalf, CopySingle, CopyDouble, CopyComplexSingle, CopyComplexDouble, PadHalf, PadSingle, PadDouble, PadComplexSingle, PadComplexDouble, TransposeHalf, TransposeSingle, TransposeDouble, TransposeComplexSingle, TransposeComplexDouble, @@ -42,9 +46,10 @@ const std::vector<Database::DatabaseEntry> Database::database = { // ================================================================================================= -// Constructor, computing device properties and populating the parameter-vector from the database +// Constructor, computing device properties and populating the parameter-vector from the database. +// This takes an optional overlay database in case of custom tuning or custom kernels. Database::Database(const Queue &queue, const std::vector<std::string> &kernels, - const Precision precision): + const Precision precision, const std::vector<DatabaseEntry> &overlay): parameters_{} { // Finds information of the current device @@ -53,10 +58,26 @@ Database::Database(const Queue &queue, const std::vector<std::string> &kernels, auto device_vendor = device.Vendor(); auto device_name = device.Name(); + // Set the short vendor name + for (auto &combination : kVendorNames) { + if (device_vendor == combination.first) { + device_vendor = combination.second; + } + } + // Iterates over all kernels to include, and retrieves the parameters for each of them for (auto &kernel: kernels) { - auto search_result = Search(kernel, device_type, device_vendor, device_name, precision); - parameters_.insert(search_result.begin(), search_result.end()); + auto search_result = ParametersPtr{}; + + for (auto db: { &overlay, &database }) { + search_result = Search(kernel, device_type, device_vendor, device_name, precision, *db); + if (search_result) { + parameters_.insert(search_result->begin(), search_result->end()); + break; + } + } + + if (!search_result) { throw std::runtime_error("Database error, could not find a suitable entry"); } } } @@ -73,28 +94,22 @@ std::string Database::GetDefines() const { // ================================================================================================= -// Searches the database for the right kernel and precision -Database::Parameters Database::Search(const std::string &this_kernel, - const std::string &this_type, - const std::string &this_vendor, - const std::string &this_device, - const Precision this_precision) const { - // Set the short vendor name - auto this_short_vendor = this_vendor; - for (auto &combination : kVendorNames) { - if (this_vendor == combination.first) { - this_short_vendor = combination.second; - } - } +// Searches a particular database for the right kernel and precision +Database::ParametersPtr Database::Search(const std::string &this_kernel, + const std::string &this_type, + const std::string &this_vendor, + const std::string &this_device, + const Precision this_precision, + const std::vector<DatabaseEntry> &this_database) const { // Selects the right kernel - for (auto &db: database) { + for (auto &db: this_database) { if (db.kernel == this_kernel && db.precision == this_precision) { // Searches for the right vendor and device type, or selects the default if unavailable. This // assumes that the default vendor / device type is last in the database. for (auto &vendor: db.vendors) { - if ((vendor.name == this_short_vendor || vendor.name == kDeviceVendorAll) && + if ((vendor.name == this_vendor || vendor.name == kDeviceVendorAll) && (vendor.type == this_type || vendor.type == kDeviceTypeAll)) { // Searches for the right device. If the current device is unavailable, selects the vendor @@ -104,7 +119,7 @@ Database::Parameters Database::Search(const std::string &this_kernel, if (device.name == this_device || device.name == "default") { // Sets the parameters accordingly - return device.parameters; + return &device.parameters; } } } @@ -112,8 +127,8 @@ Database::Parameters Database::Search(const std::string &this_kernel, } } - // If we reached this point, something is wrong - throw std::runtime_error("Database error, could not find a suitable entry"); + // If we reached this point, the entry was not found in this database + return nullptr; } // ================================================================================================= |