summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
author: Cedric Nugteren <web@cedricnugteren.nl> 2016-06-18 18:16:14 +0200
committer: Cedric Nugteren <web@cedricnugteren.nl> 2016-06-18 18:16:14 +0200
commit: bacb5d2bb2ea7b141034878090aca850db8f9d00 (patch)
tree: 7315f72f18c93fa02302e58e2718d2fbfd9db361 /src
parent: 7b4c0e1cf03a94077c20f7f12ef15fb8717c74ca (diff)
Clean-up of the routine class, moved RunKernel to the routine/common file
Diffstat (limited to 'src')
-rw-r--r-- src/clblast.cc 14
-rw-r--r-- src/routine.cc 45
-rw-r--r-- src/routines/common.cc 65
3 files changed, 74 insertions, 50 deletions
diff --git a/src/clblast.cc b/src/clblast.cc
index 2d6776d0..d0f0c937 100644
--- a/src/clblast.cc
+++ b/src/clblast.cc
@@ -29,10 +29,10 @@
#include "internal/routines/level1/xdotc.h"
#include "internal/routines/level1/xnrm2.h"
#include "internal/routines/level1/xasum.h"
-#include "internal/routines/level1/xsum.h" // non-BLAS function
+#include "internal/routines/level1/xsum.h" // non-BLAS routine
#include "internal/routines/level1/xamax.h"
-#include "internal/routines/level1/xmax.h" // non-BLAS function
-#include "internal/routines/level1/xmin.h" // non-BLAS function
+#include "internal/routines/level1/xmax.h" // non-BLAS routine
+#include "internal/routines/level1/xmin.h" // non-BLAS routine
// BLAS level-2 includes
#include "internal/routines/level2/xgemv.h"
@@ -68,7 +68,7 @@
#include "internal/routines/level3/xher2k.h"
#include "internal/routines/level3/xtrmm.h"
-// Extra includes (level-x)
+// Level-x includes (non-BLAS)
#include "internal/routines/levelx/xomatcopy.h"
namespace clblast {
@@ -2123,6 +2123,7 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose,
StatusCode ClearCache() { return CacheClearAll(); }
// Fills the cache with all binaries for a specific device
+// TODO: Add half-precision FP16 set-up calls
StatusCode FillCache(const cl_device_id device) {
try {
@@ -2171,7 +2172,7 @@ StatusCode FillCache(const cl_device_id device) {
Xsyr2<float>(queue, nullptr).SetUp(); Xsyr2<double>(queue, nullptr).SetUp();
Xspr2<float>(queue, nullptr).SetUp(); Xspr2<double>(queue, nullptr).SetUp();
- // Runs all the level 1 set-up functions
+ // Runs all the level 3 set-up functions
Xgemm<float>(queue, nullptr).SetUp(); Xgemm<double>(queue, nullptr).SetUp(); Xgemm<float2>(queue, nullptr).SetUp(); Xgemm<double2>(queue, nullptr).SetUp();
Xsymm<float>(queue, nullptr).SetUp(); Xsymm<double>(queue, nullptr).SetUp(); Xsymm<float2>(queue, nullptr).SetUp(); Xsymm<double2>(queue, nullptr).SetUp();
Xhemm<float2>(queue, nullptr).SetUp(); Xhemm<double2>(queue, nullptr).SetUp();
@@ -2181,6 +2182,9 @@ StatusCode FillCache(const cl_device_id device) {
Xher2k<float2,float>(queue, nullptr).SetUp(); Xher2k<double2,double>(queue, nullptr).SetUp();
Xtrmm<float>(queue, nullptr).SetUp(); Xtrmm<double>(queue, nullptr).SetUp(); Xtrmm<float2>(queue, nullptr).SetUp(); Xtrmm<double2>(queue, nullptr).SetUp();
+ // Runs all the level-x (non-BLAS) set-up functions
+ Xomatcopy<float>(queue, nullptr).SetUp(); Xomatcopy<double>(queue, nullptr).SetUp(); Xomatcopy<float2>(queue, nullptr).SetUp(); Xomatcopy<double2>(queue, nullptr).SetUp();
+
} catch (...) { return StatusCode::kBuildProgramFailure; }
return StatusCode::kSuccess;
}
diff --git a/src/routine.cc b/src/routine.cc
index 9b1640b5..11633ede 100644
--- a/src/routine.cc
+++ b/src/routine.cc
@@ -128,49 +128,4 @@ StatusCode Routine::SetUp() {
}
// =================================================================================================
-
-// Enqueues a kernel, waits for completion, and checks for errors
-StatusCode RunKernel(Kernel &kernel, Queue queue, const Device device,
- std::vector<size_t> global, const std::vector<size_t> &local,
- EventPointer event, std::vector<Event>& waitForEvents) {
-
- // Tests for validity of the local thread sizes
- if (local.size() > device.MaxWorkItemDimensions()) {
- return StatusCode::kInvalidLocalNumDimensions;
- }
- const auto max_work_item_sizes = device.MaxWorkItemSizes();
- for (auto i=size_t{0}; i<local.size(); ++i) {
- if (local[i] > max_work_item_sizes[i]) { return StatusCode::kInvalidLocalThreadsDim; }
- }
- auto local_size = size_t{1};
- for (auto &item: local) { local_size *= item; }
- if (local_size > device.MaxWorkGroupSize()) { return StatusCode::kInvalidLocalThreadsTotal; }
-
- // Make sure the global thread sizes are at least equal to the local sizes
- for (auto i=size_t{0}; i<global.size(); ++i) {
- if (global[i] < local[i]) { global[i] = local[i]; }
- }
-
- // Tests for local memory usage
- const auto local_mem_usage = kernel.LocalMemUsage(device);
- if (!device.IsLocalMemoryValid(local_mem_usage)) { return StatusCode::kInvalidLocalMemUsage; }
-
- // Launches the kernel (and checks for launch errors)
- try {
- kernel.Launch(queue, global, local, event, waitForEvents);
- } catch (...) { return StatusCode::kKernelLaunchError; }
-
- // No errors, normal termination of this function
- return StatusCode::kSuccess;
-}
-
-// As above, but without an event waiting list
-StatusCode RunKernel(Kernel &kernel, Queue queue, const Device device,
- std::vector<size_t> global, const std::vector<size_t> &local,
- EventPointer event) {
- auto emptyWaitingList = std::vector<Event>();
- return RunKernel(kernel, queue, device, global, local, event, emptyWaitingList);
-}
-
-// =================================================================================================
} // namespace clblast
diff --git a/src/routines/common.cc b/src/routines/common.cc
new file mode 100644
index 00000000..561a1bd8
--- /dev/null
+++ b/src/routines/common.cc
@@ -0,0 +1,65 @@
+
+// =================================================================================================
+// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
+// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
+// width of 100 characters per line.
+//
+// Author(s):
+// Cedric Nugteren <www.cedricnugteren.nl>
+//
+// This file implements the common routine functions (see the header for more information).
+//
+// =================================================================================================
+
+#include <vector>
+
+#include "internal/routines/common.h"
+
+namespace clblast {
+// =================================================================================================
+
+// Checks a kernel launch configuration against the limits reported by the device and, if all
+// checks pass, launches the kernel on the given queue.
+//
+// Arguments:
+//   kernel         the kernel to launch; also queried for its local memory usage on 'device'
+//   queue          the queue passed on to Kernel::Launch
+//   device         the device whose limits are used for validation
+//   global         requested global thread sizes; taken by value because entries smaller than the
+//                  matching local size are raised to that local size before launching
+//   local          requested local thread sizes
+//   event          passed on unchanged to Kernel::Launch
+//   waitForEvents  passed on unchanged to Kernel::Launch
+//
+// Returns:
+//   kInvalidLocalNumDimensions  'local' has more entries than device.MaxWorkItemDimensions()
+//   kInvalidLocalThreadsDim     a local size exceeds the device maximum for that dimension
+//   kInvalidLocalThreadsTotal   the product of the local sizes exceeds device.MaxWorkGroupSize()
+//   kInvalidLocalMemUsage       the device rejects the kernel's local memory usage
+//   kKernelLaunchError          Kernel::Launch threw an exception (of any type)
+//   kSuccess                    otherwise
+//
+// NOTE(review): whether Kernel::Launch blocks until the kernel has finished is not visible in
+// this file; confirm against the Kernel class before relying on completion here.
+StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device,
+                     std::vector<size_t> global, const std::vector<size_t> &local,
+                     EventPointer event, std::vector<Event>& waitForEvents) {
+
+  // Tests for validity of the local thread sizes
+  if (local.size() > device.MaxWorkItemDimensions()) {
+    return StatusCode::kInvalidLocalNumDimensions;
+  }
+  const auto max_work_item_sizes = device.MaxWorkItemSizes();
+  for (auto i=size_t{0}; i<local.size(); ++i) {
+    if (local[i] > max_work_item_sizes[i]) { return StatusCode::kInvalidLocalThreadsDim; }
+  }
+  auto local_size = size_t{1};
+  for (const auto &item: local) { local_size *= item; }
+  if (local_size > device.MaxWorkGroupSize()) { return StatusCode::kInvalidLocalThreadsTotal; }
+
+  // Make sure the global thread sizes are at least equal to the local sizes. Only dimensions that
+  // exist in both vectors are compared, so 'local' is never indexed past its end when 'global'
+  // has more entries.
+  for (auto i=size_t{0}; i<global.size() && i<local.size(); ++i) {
+    if (global[i] < local[i]) { global[i] = local[i]; }
+  }
+
+  // Tests for local memory usage
+  const auto local_mem_usage = kernel.LocalMemUsage(device);
+  if (!device.IsLocalMemoryValid(local_mem_usage)) { return StatusCode::kInvalidLocalMemUsage; }
+
+  // Launches the kernel; any exception is translated into a status code for the caller
+  try {
+    kernel.Launch(queue, global, local, event, waitForEvents);
+  } catch (...) { return StatusCode::kKernelLaunchError; }
+
+  // No errors, normal termination of this function
+  return StatusCode::kSuccess;
+}
+
+// Convenience overload for launches that do not depend on earlier events: forwards every
+// argument to the RunKernel overload above together with an empty waiting list, and returns
+// that overload's status code.
+StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device,
+                     std::vector<size_t> global, const std::vector<size_t> &local,
+                     EventPointer event) {
+  std::vector<Event> noDependencies;
+  return RunKernel(kernel, queue, device, global, local, event, noDependencies);
+}
+
+// =================================================================================================
+} // namespace clblast