summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Cedric Nugteren <web@cedricnugteren.nl> 2017-10-08 21:52:02 +0200
committer Cedric Nugteren <web@cedricnugteren.nl> 2017-10-08 21:52:02 +0200
commit df3c9f4a8ab9e82ccc4add15b04da5c1b6172b72 (patch)
tree 834d7f481d0370ce5e467e940396ee7c3738eb3c
parent 2bb8402ec15a672eaa26595247aa09f7d88fecdb (diff)
Moved non-routine-specific API functions and includes to separate files
-rw-r--r-- CMakeLists.txt 2
-rwxr-xr-x scripts/generator/generator.py 4
-rw-r--r-- src/api_common.cpp 169
-rw-r--r-- src/clblast.cpp 207
-rw-r--r-- src/routines/routines.hpp 76
5 files changed, 250 insertions, 208 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 05e7393b..52accbd4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -182,6 +182,7 @@ set(SOURCES
src/routines/common.cpp
src/utilities/clblast_exceptions.cpp
src/utilities/utilities.cpp
+ src/api_common.cpp
src/cache.cpp
src/clblast.cpp
src/clblast_c.cpp
@@ -201,6 +202,7 @@ set(HEADERS # such that they can be discovered by IDEs such as CLion and Visual
src/routines/level1/xmin.hpp
src/routines/level1/xsum.hpp
src/routines/common.hpp
+ src/routines/routines.hpp
src/utilities/buffer_test.hpp
src/utilities/clblast_exceptions.hpp
src/utilities/device_mapping.hpp
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index df0eaca0..0d34d7fe 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -42,8 +42,8 @@ FILES = [
"/include/clblast_netlib_c.h",
"/src/clblast_netlib_c.cpp",
]
-HEADER_LINES = [122, 79, 126, 24, 29, 41, 29, 65, 32]
-FOOTER_LINES = [25, 147, 27, 38, 6, 6, 6, 9, 2]
+HEADER_LINES = [122, 21, 126, 24, 29, 41, 29, 65, 32]
+FOOTER_LINES = [25, 3, 27, 38, 6, 6, 6, 9, 2]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 63
diff --git a/src/api_common.cpp b/src/api_common.cpp
new file mode 100644
index 00000000..aa7e2b0f
--- /dev/null
+++ b/src/api_common.cpp
@@ -0,0 +1,169 @@
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file implements the common (non-OpenCL-specific) functions of the CLBlast API.
+//
+// =================================================================================================
+
+#include <string>
+
+#include "cache.hpp"
+#include "routines/routines.hpp"
+#include "clblast.h"
+
+namespace clblast {
+// =================================================================================================
+
+// Clears the program and binary caches; any exception is mapped to a status by DispatchException()
+StatusCode ClearCache() {
+  try {
+    ProgramCache::Instance().Invalidate();
+    BinaryCache::Instance().Invalidate();
+    return StatusCode::kSuccess;
+  } catch (...) { return DispatchException(); }
+}
+
+// Runs the set-up of each routine listed below once, for the given real and complex types.
+// NOTE(review): the amin, trsv, trsm, im2col and batched routines are included via routines.hpp
+// but are not set up here - confirm whether that omission is intentional.
+template <typename Real, typename Complex>
+void FillCacheForPrecision(Queue &queue) {
+  try {
+    // Runs all the level 1 set-up functions
+    Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr);
+    Xscal<Real>(queue, nullptr); Xscal<Complex>(queue, nullptr);
+    Xcopy<Real>(queue, nullptr); Xcopy<Complex>(queue, nullptr);
+    Xaxpy<Real>(queue, nullptr); Xaxpy<Complex>(queue, nullptr);
+    Xdot<Real>(queue, nullptr);
+    Xdotu<Complex>(queue, nullptr);
+    Xdotc<Complex>(queue, nullptr);
+    Xnrm2<Real>(queue, nullptr); Xnrm2<Complex>(queue, nullptr);
+    Xasum<Real>(queue, nullptr); Xasum<Complex>(queue, nullptr);
+    Xsum<Real>(queue, nullptr); Xsum<Complex>(queue, nullptr);
+    Xamax<Real>(queue, nullptr); Xamax<Complex>(queue, nullptr);
+    Xmax<Real>(queue, nullptr); Xmax<Complex>(queue, nullptr);
+    Xmin<Real>(queue, nullptr); Xmin<Complex>(queue, nullptr);
+
+    // Runs all the level 2 set-up functions
+    Xgemv<Real>(queue, nullptr); Xgemv<Complex>(queue, nullptr);
+    Xgbmv<Real>(queue, nullptr); Xgbmv<Complex>(queue, nullptr);
+    Xhemv<Complex>(queue, nullptr);
+    Xhbmv<Complex>(queue, nullptr);
+    Xhpmv<Complex>(queue, nullptr);
+    Xsymv<Real>(queue, nullptr);
+    Xsbmv<Real>(queue, nullptr);
+    Xspmv<Real>(queue, nullptr);
+    Xtrmv<Real>(queue, nullptr); Xtrmv<Complex>(queue, nullptr);
+    Xtbmv<Real>(queue, nullptr); Xtbmv<Complex>(queue, nullptr);
+    Xtpmv<Real>(queue, nullptr); Xtpmv<Complex>(queue, nullptr);
+    Xger<Real>(queue, nullptr);
+    Xgeru<Complex>(queue, nullptr);
+    Xgerc<Complex>(queue, nullptr);
+    Xher<Complex,Real>(queue, nullptr);
+    Xhpr<Complex,Real>(queue, nullptr);
+    Xher2<Complex>(queue, nullptr);
+    Xhpr2<Complex>(queue, nullptr);
+    Xsyr<Real>(queue, nullptr);
+    Xspr<Real>(queue, nullptr);
+    Xsyr2<Real>(queue, nullptr);
+    Xspr2<Real>(queue, nullptr);
+
+    // Runs all the level 3 set-up functions
+    Xgemm<Real>(queue, nullptr); Xgemm<Complex>(queue, nullptr);
+    Xsymm<Real>(queue, nullptr); Xsymm<Complex>(queue, nullptr);
+    Xhemm<Complex>(queue, nullptr);
+    Xsyrk<Real>(queue, nullptr); Xsyrk<Complex>(queue, nullptr);
+    Xherk<Complex,Real>(queue, nullptr);
+    Xsyr2k<Real>(queue, nullptr); Xsyr2k<Complex>(queue, nullptr);
+    Xher2k<Complex,Real>(queue, nullptr);
+    Xtrmm<Real>(queue, nullptr); Xtrmm<Complex>(queue, nullptr);
+
+    // Runs all the non-BLAS set-up functions
+    Xomatcopy<Real>(queue, nullptr); Xomatcopy<Complex>(queue, nullptr);
+  } catch(const RuntimeErrorCode &e) {
+    // kNoDoublePrecision/kNoHalfPrecision are tolerated, but the first one thrown skips all the
+    // remaining set-up calls above, as they share this try block. Other errors are re-thrown.
+    if (e.status() != StatusCode::kNoDoublePrecision &&
+        e.status() != StatusCode::kNoHalfPrecision) { throw; }
+  }
+}
+
+// Fills the cache with all binaries for a specific device, for the float and double type pairs
+// TODO: Add half-precision FP16 set-up calls
+StatusCode FillCache(const RawDeviceID device) {
+  try {
+
+    // A sample context and queue are built around the device to match the routine conventions
+    Device device_cpp(device);
+    Context context(device_cpp);
+    Queue queue(context, device_cpp);
+
+    // Single-precision set-up first, then double precision
+    FillCacheForPrecision<float, float2>(queue);
+    FillCacheForPrecision<double, double2>(queue);
+    return StatusCode::kSuccess;
+  } catch (...) { return DispatchException(); }
+}
+
+// =================================================================================================
+
+// Overrides the tuning parameters for this device-precision-kernel combination. 'parameters' must
+// have as many entries as the current database has parameter names and must contain each of those
+// names, else kMissingOverrideParameter is returned. Exceptions go through DispatchException().
+StatusCode OverrideParameters(const RawDeviceID device, const std::string &kernel_name,
+                              const Precision precision,
+                              const std::unordered_map<std::string,size_t> &parameters) {
+  try {
+    // Wraps the raw device and retrieves its platform
+    const auto device_cpp = Device(device);
+    const auto platform_id = device_cpp.PlatformID();
+
+    // Retrieves the current database values to verify whether the new ones are complete
+    auto in_cache = false;
+    auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device, precision, kernel_name}, &in_cache);
+    if (!in_cache) {
+      log_debug("Searching database for kernel '" + kernel_name + "'");
+      current_database = Database(device_cpp, kernel_name, precision, {});
+    }
+
+    // Verifies the parameters size
+    const auto current_parameter_names = current_database.GetParameterNames();
+    if (current_parameter_names.size() != parameters.size()) {
+      return StatusCode::kMissingOverrideParameter;
+    }
+
+    // Retrieves the names and values separately and in the same order as the existing database
+    auto parameter_values = database::Params{0};
+    auto i = size_t{0};
+    for (const auto &current_param : current_parameter_names) {
+      // One lookup serves both the presence check and the value retrieval
+      const auto found = parameters.find(current_param);
+      if (found == parameters.end()) {
+        return StatusCode::kMissingOverrideParameter;
+      }
+      parameter_values[i] = found->second;
+      ++i;
+    }
+
+    // Creates a small custom database based on the provided parameters
+    const auto database_device = database::DatabaseDevice{database::kDeviceNameDefault, parameter_values};
+    const auto database_architecture = database::DatabaseArchitecture{"default", {database_device}};
+    const auto database_vendor = database::DatabaseVendor{database::kDeviceTypeAll, "default", {database_architecture}};
+    const auto database_entry = database::DatabaseEntry{kernel_name, precision, current_parameter_names, {database_vendor}};
+    const auto database_entries = std::vector<database::DatabaseEntry>{database_entry};
+    const auto database = Database(device_cpp, kernel_name, precision, database_entries);
+
+    // Removes the old database entry and stores the new one in the cache
+    DatabaseCache::Instance().Remove(DatabaseKey{platform_id, device, precision, kernel_name});
+    DatabaseCache::Instance().Store(DatabaseKey{platform_id, device, precision, kernel_name}, Database(database));
+  } catch (...) { return DispatchException(); }
+  return StatusCode::kSuccess;
+}
+
+// =================================================================================================
+} // namespace clblast
diff --git a/src/clblast.cpp b/src/clblast.cpp
index 9f865a23..7d2c2cef 100644
--- a/src/clblast.cpp
+++ b/src/clblast.cpp
@@ -15,67 +15,9 @@
#include <string>
-#include "cache.hpp"
+#include "routines/routines.hpp"
#include "clblast.h"
-// BLAS level-1 includes
-#include "routines/level1/xswap.hpp"
-#include "routines/level1/xscal.hpp"
-#include "routines/level1/xcopy.hpp"
-#include "routines/level1/xaxpy.hpp"
-#include "routines/level1/xdot.hpp"
-#include "routines/level1/xdotu.hpp"
-#include "routines/level1/xdotc.hpp"
-#include "routines/level1/xnrm2.hpp"
-#include "routines/level1/xasum.hpp"
-#include "routines/level1/xsum.hpp" // non-BLAS routine
-#include "routines/level1/xamax.hpp"
-#include "routines/level1/xamin.hpp" // non-BLAS routine
-#include "routines/level1/xmax.hpp" // non-BLAS routine
-#include "routines/level1/xmin.hpp" // non-BLAS routine
-
-// BLAS level-2 includes
-#include "routines/level2/xgemv.hpp"
-#include "routines/level2/xgbmv.hpp"
-#include "routines/level2/xhemv.hpp"
-#include "routines/level2/xhbmv.hpp"
-#include "routines/level2/xhpmv.hpp"
-#include "routines/level2/xsymv.hpp"
-#include "routines/level2/xsbmv.hpp"
-#include "routines/level2/xspmv.hpp"
-#include "routines/level2/xtrmv.hpp"
-#include "routines/level2/xtbmv.hpp"
-#include "routines/level2/xtpmv.hpp"
-#include "routines/level2/xtrsv.hpp"
-#include "routines/level2/xger.hpp"
-#include "routines/level2/xgeru.hpp"
-#include "routines/level2/xgerc.hpp"
-#include "routines/level2/xher.hpp"
-#include "routines/level2/xhpr.hpp"
-#include "routines/level2/xher2.hpp"
-#include "routines/level2/xhpr2.hpp"
-#include "routines/level2/xsyr.hpp"
-#include "routines/level2/xspr.hpp"
-#include "routines/level2/xsyr2.hpp"
-#include "routines/level2/xspr2.hpp"
-
-// BLAS level-3 includes
-#include "routines/level3/xgemm.hpp"
-#include "routines/level3/xsymm.hpp"
-#include "routines/level3/xhemm.hpp"
-#include "routines/level3/xsyrk.hpp"
-#include "routines/level3/xherk.hpp"
-#include "routines/level3/xsyr2k.hpp"
-#include "routines/level3/xher2k.hpp"
-#include "routines/level3/xtrmm.hpp"
-#include "routines/level3/xtrsm.hpp"
-
-// Level-x includes (non-BLAS)
-#include "routines/levelx/xomatcopy.hpp"
-#include "routines/levelx/xim2col.hpp"
-#include "routines/levelx/xaxpybatched.hpp"
-#include "routines/levelx/xgemmbatched.hpp"
-
namespace clblast {
// =================================================================================================
@@ -2389,153 +2331,6 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose,
cl_mem, const size_t*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
-// =================================================================================================
-
-// Clears the cache of stored binaries
-StatusCode ClearCache() {
- try {
- ProgramCache::Instance().Invalidate();
- BinaryCache::Instance().Invalidate();
- } catch (...) { return DispatchException(); }
- return StatusCode::kSuccess;
-}
-
-template <typename Real, typename Complex>
-void FillCacheForPrecision(Queue &queue) {
- try {
-
- // Runs all the level 1 set-up functions
- Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr);
- Xswap<Real>(queue, nullptr); Xswap<Complex>(queue, nullptr);
- Xscal<Real>(queue, nullptr); Xscal<Complex>(queue, nullptr);
- Xcopy<Real>(queue, nullptr); Xcopy<Complex>(queue, nullptr);
- Xaxpy<Real>(queue, nullptr); Xaxpy<Complex>(queue, nullptr);
- Xdot<Real>(queue, nullptr);
- Xdotu<Complex>(queue, nullptr);
- Xdotc<Complex>(queue, nullptr);
- Xnrm2<Real>(queue, nullptr); Xnrm2<Complex>(queue, nullptr);
- Xasum<Real>(queue, nullptr); Xasum<Complex>(queue, nullptr);
- Xsum<Real>(queue, nullptr); Xsum<Complex>(queue, nullptr);
- Xamax<Real>(queue, nullptr); Xamax<Complex>(queue, nullptr);
- Xmax<Real>(queue, nullptr); Xmax<Complex>(queue, nullptr);
- Xmin<Real>(queue, nullptr); Xmin<Complex>(queue, nullptr);
-
- // Runs all the level 2 set-up functions
- Xgemv<Real>(queue, nullptr); Xgemv<Complex>(queue, nullptr);
- Xgbmv<Real>(queue, nullptr); Xgbmv<Complex>(queue, nullptr);
- Xhemv<Complex>(queue, nullptr);
- Xhbmv<Complex>(queue, nullptr);
- Xhpmv<Complex>(queue, nullptr);
- Xsymv<Real>(queue, nullptr);
- Xsbmv<Real>(queue, nullptr);
- Xspmv<Real>(queue, nullptr);
- Xtrmv<Real>(queue, nullptr); Xtrmv<Complex>(queue, nullptr);
- Xtbmv<Real>(queue, nullptr); Xtbmv<Complex>(queue, nullptr);
- Xtpmv<Real>(queue, nullptr); Xtpmv<Complex>(queue, nullptr);
- Xger<Real>(queue, nullptr);
- Xgeru<Complex>(queue, nullptr);
- Xgerc<Complex>(queue, nullptr);
- Xher<Complex,Real>(queue, nullptr);
- Xhpr<Complex,Real>(queue, nullptr);
- Xher2<Complex>(queue, nullptr);
- Xhpr2<Complex>(queue, nullptr);
- Xsyr<Real>(queue, nullptr);
- Xspr<Real>(queue, nullptr);
- Xsyr2<Real>(queue, nullptr);
- Xspr2<Real>(queue, nullptr);
-
- // Runs all the level 3 set-up functions
- Xgemm<Real>(queue, nullptr); Xgemm<Complex>(queue, nullptr);
- Xsymm<Real>(queue, nullptr); Xsymm<Complex>(queue, nullptr);
- Xhemm<Complex>(queue, nullptr);
- Xsyrk<Real>(queue, nullptr); Xsyrk<Complex>(queue, nullptr);
- Xherk<Complex,Real>(queue, nullptr);
- Xsyr2k<Real>(queue, nullptr); Xsyr2k<Complex>(queue, nullptr);
- Xher2k<Complex,Real>(queue, nullptr);
- Xtrmm<Real>(queue, nullptr); Xtrmm<Complex>(queue, nullptr);
-
- // Runs all the non-BLAS set-up functions
- Xomatcopy<Real>(queue, nullptr); Xomatcopy<Complex>(queue, nullptr);
-
- } catch(const RuntimeErrorCode &e) {
- if (e.status() != StatusCode::kNoDoublePrecision &&
- e.status() != StatusCode::kNoHalfPrecision) {
- throw;
- }
- }
-}
-
-// Fills the cache with all binaries for a specific device
-// TODO: Add half-precision FP16 set-up calls
-StatusCode FillCache(const cl_device_id device) {
- try {
-
- // Creates a sample context and queue to match the normal routine calling conventions
- auto device_cpp = Device(device);
- auto context = Context(device_cpp);
- auto queue = Queue(context, device_cpp);
-
- FillCacheForPrecision<float, float2>(queue);
- FillCacheForPrecision<double, double2>(queue);
-
- } catch (...) { return DispatchException(); }
- return StatusCode::kSuccess;
-}
-
-// =================================================================================================
-
-// Overrides the tuning parameters for this device-precision-kernel combination
-StatusCode OverrideParameters(const cl_device_id device, const std::string &kernel_name,
- const Precision precision,
- const std::unordered_map<std::string,size_t> &parameters) {
- try {
-
- // Retrieves the device name
- const auto device_cpp = Device(device);
- const auto platform_id = device_cpp.PlatformID();
- const auto device_name = GetDeviceName(device_cpp);
-
- // Retrieves the current database values to verify whether the new ones are complete
- auto in_cache = false;
- auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device, precision, kernel_name}, &in_cache);
- if (!in_cache) {
- log_debug("Searching database for kernel '" + kernel_name + "'");
- current_database = Database(device_cpp, kernel_name, precision, {});
- }
-
- // Verifies the parameters size
- const auto current_parameter_names = current_database.GetParameterNames();
- if (current_parameter_names.size() != parameters.size()) {
- return StatusCode::kMissingOverrideParameter;
- }
-
- // Retrieves the names and values separately and in the same order as the existing database
- auto parameter_values = database::Params{0};
- auto i = size_t{0};
- for (const auto &current_param : current_parameter_names) {
- if (parameters.find(current_param) == parameters.end()) {
- return StatusCode::kMissingOverrideParameter;
- }
- const auto parameter_value = parameters.at(current_param);
- parameter_values[i] = parameter_value;
- ++i;
- }
-
- // Creates a small custom database based on the provided parameters
- const auto database_device = database::DatabaseDevice{database::kDeviceNameDefault, parameter_values};
- const auto database_architecture = database::DatabaseArchitecture{"default", {database_device}};
- const auto database_vendor = database::DatabaseVendor{database::kDeviceTypeAll, "default", {database_architecture}};
- const auto database_entry = database::DatabaseEntry{kernel_name, precision, current_parameter_names, {database_vendor}};
- const auto database_entries = std::vector<database::DatabaseEntry>{database_entry};
- const auto database = Database(device_cpp, kernel_name, precision, database_entries);
-
- // Removes the old database entry and stores the new one in the cache
- DatabaseCache::Instance().Remove(DatabaseKey{platform_id, device, precision, kernel_name});
- DatabaseCache::Instance().Store(DatabaseKey{platform_id, device, precision, kernel_name}, Database(database));
-
- } catch (...) { return DispatchException(); }
- return StatusCode::kSuccess;
-}
// =================================================================================================
} // namespace clblast
diff --git a/src/routines/routines.hpp b/src/routines/routines.hpp
new file mode 100644
index 00000000..9e7768b9
--- /dev/null
+++ b/src/routines/routines.hpp
@@ -0,0 +1,76 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file contains all the includes of all the routines in CLBlast.
+//
+// =================================================================================================
+
+#ifndef CLBLAST_ROUTINES_ROUTINES_H_
+#define CLBLAST_ROUTINES_ROUTINES_H_
+
+// BLAS level-1 includes
+#include "routines/level1/xswap.hpp"
+#include "routines/level1/xscal.hpp"
+#include "routines/level1/xcopy.hpp"
+#include "routines/level1/xaxpy.hpp"
+#include "routines/level1/xdot.hpp"
+#include "routines/level1/xdotu.hpp"
+#include "routines/level1/xdotc.hpp"
+#include "routines/level1/xnrm2.hpp"
+#include "routines/level1/xasum.hpp"
+#include "routines/level1/xsum.hpp" // non-BLAS routine
+#include "routines/level1/xamax.hpp"
+#include "routines/level1/xamin.hpp" // non-BLAS routine
+#include "routines/level1/xmax.hpp" // non-BLAS routine
+#include "routines/level1/xmin.hpp" // non-BLAS routine
+
+// BLAS level-2 includes
+#include "routines/level2/xgemv.hpp"
+#include "routines/level2/xgbmv.hpp"
+#include "routines/level2/xhemv.hpp"
+#include "routines/level2/xhbmv.hpp"
+#include "routines/level2/xhpmv.hpp"
+#include "routines/level2/xsymv.hpp"
+#include "routines/level2/xsbmv.hpp"
+#include "routines/level2/xspmv.hpp"
+#include "routines/level2/xtrmv.hpp"
+#include "routines/level2/xtbmv.hpp"
+#include "routines/level2/xtpmv.hpp"
+#include "routines/level2/xtrsv.hpp"
+#include "routines/level2/xger.hpp"
+#include "routines/level2/xgeru.hpp"
+#include "routines/level2/xgerc.hpp"
+#include "routines/level2/xher.hpp"
+#include "routines/level2/xhpr.hpp"
+#include "routines/level2/xher2.hpp"
+#include "routines/level2/xhpr2.hpp"
+#include "routines/level2/xsyr.hpp"
+#include "routines/level2/xspr.hpp"
+#include "routines/level2/xsyr2.hpp"
+#include "routines/level2/xspr2.hpp"
+
+// BLAS level-3 includes
+#include "routines/level3/xgemm.hpp"
+#include "routines/level3/xsymm.hpp"
+#include "routines/level3/xhemm.hpp"
+#include "routines/level3/xsyrk.hpp"
+#include "routines/level3/xherk.hpp"
+#include "routines/level3/xsyr2k.hpp"
+#include "routines/level3/xher2k.hpp"
+#include "routines/level3/xtrmm.hpp"
+#include "routines/level3/xtrsm.hpp"
+
+// Level-x includes (non-BLAS)
+#include "routines/levelx/xomatcopy.hpp"
+#include "routines/levelx/xim2col.hpp"
+#include "routines/levelx/xaxpybatched.hpp"
+#include "routines/levelx/xgemmbatched.hpp"
+
+// CLBLAST_ROUTINES_ROUTINES_H_
+#endif