summaryrefslogtreecommitdiff
path: root/src/database.cc
blob: dc72dbddf9d2b8b8b9f441fa6470f870376672dd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
//   Cedric Nugteren <www.cedricnugteren.nl>
//
// This file implements the Database class (see the header for information about the class).
//
// =================================================================================================

#include "internal/database.h"
#include "internal/database/xaxpy.h"
#include "internal/database/xdot.h"
#include "internal/database/xgemv.h"
#include "internal/database/xger.h"
#include "internal/database/xgemm.h"
#include "internal/database/copy.h"
#include "internal/database/pad.h"
#include "internal/database/transpose.h"
#include "internal/database/padtranspose.h"

#include "internal/utilities.h"

namespace clblast {
// =================================================================================================

// Initializes the database
const std::vector<Database::DatabaseEntry> Database::database = {
  XaxpyHalf, XaxpySingle, XaxpyDouble, XaxpyComplexSingle, XaxpyComplexDouble,
  XdotHalf, XdotSingle, XdotDouble, XdotComplexSingle, XdotComplexDouble,
  XgemvHalf, XgemvSingle, XgemvDouble, XgemvComplexSingle, XgemvComplexDouble,
  XgerHalf, XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble,
  XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble,
  CopyHalf, CopySingle, CopyDouble, CopyComplexSingle, CopyComplexDouble,
  PadHalf, PadSingle, PadDouble, PadComplexSingle, PadComplexDouble,
  TransposeHalf, TransposeSingle, TransposeDouble, TransposeComplexSingle, TransposeComplexDouble,
  PadtransposeHalf, PadtransposeSingle, PadtransposeDouble, PadtransposeComplexSingle, PadtransposeComplexDouble
};

// =================================================================================================

// Constructor: reads the type, vendor, and name of the device behind the given queue, looks up the
// parameters of every requested kernel for that device and precision, and merges them all into the
// parameter container of this object. Propagates any exception thrown by Search().
Database::Database(const Queue &queue, const std::vector<std::string> &kernels,
                   const Precision precision):
  parameters_{} {

  // The device properties below serve as the search keys into the database
  auto target = queue.GetDevice();
  const auto target_type = target.Type();
  const auto target_vendor = target.Vendor();
  const auto target_name = target.Name();

  // Performs one database look-up per requested kernel and accumulates the results
  for (const auto &kernel_name: kernels) {
    const auto found = Search(kernel_name, target_type, target_vendor, target_name, precision);
    parameters_.insert(found.begin(), found.end());
  }
}

// =================================================================================================

// Returns a list of OpenCL pre-processor defines in string form
std::string Database::GetDefines() const {
  std::string defines{};
  for (auto &parameter: parameters_) {
    defines += "#define "+parameter.first+" "+ToString(parameter.second)+"\n";
  }
  return defines;
}

// =================================================================================================

// Looks up the parameters of a single kernel at a single precision for the given device. The first
// database entry with a matching kernel name and precision is scanned for a vendor / device-type
// combination that matches (or is the catch-all), and within it for a device whose name matches (or
// is "default"). The first hit is returned. Because the first hit wins, the catch-all vendors and
// the "default" devices are expected to be listed last in the database.
// Throws std::runtime_error when no suitable entry exists.
Database::Parameters Database::Search(const std::string &this_kernel,
                                      const std::string &this_type,
                                      const std::string &this_vendor,
                                      const std::string &this_device,
                                      const Precision this_precision) const {

  // Translates the full vendor string into its short form when a translation is known; otherwise
  // the vendor string is used as-is
  auto short_vendor = this_vendor;
  for (auto &translation : kVendorNames) {
    if (this_vendor == translation.first) { short_vendor = translation.second; }
  }

  for (auto &entry: database) {

    // Skips the entries belonging to other kernels or other precisions
    const bool kernel_matches = (entry.kernel == this_kernel &&
                                 entry.precision == this_precision);
    if (!kernel_matches) { continue; }

    for (auto &vendor: entry.vendors) {

      // Accepts either an exact match or the catch-all value, both for the vendor name and for the
      // device type
      const bool name_matches = (vendor.name == short_vendor ||
                                 vendor.name == kDeviceVendorAll);
      const bool type_matches = (vendor.type == this_type ||
                                 vendor.type == kDeviceTypeAll);
      if (!(name_matches && type_matches)) { continue; }

      // Returns the parameters of the first device that is either the requested one or the
      // vendor-wide "default". If neither is present, the search continues with the next vendor.
      for (auto &candidate: vendor.devices) {
        if (candidate.name == this_device || candidate.name == "default") {
          return candidate.parameters;
        }
      }
    }
  }

  // Nothing matched: not even a default entry is available for this kernel and precision
  throw std::runtime_error("Database error, could not find a suitable entry");
}

// =================================================================================================
} // namespace clblast