diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2016-06-18 18:16:14 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2016-06-18 18:16:14 +0200 |
commit | bacb5d2bb2ea7b141034878090aca850db8f9d00 (patch) | |
tree | 7315f72f18c93fa02302e58e2718d2fbfd9db361 /src | |
parent | 7b4c0e1cf03a94077c20f7f12ef15fb8717c74ca (diff) |
Clean-up of the routine class, moved RunKernel to the routine/common file
Diffstat (limited to 'src')
-rw-r--r-- | src/clblast.cc | 14 | ||||
-rw-r--r-- | src/routine.cc | 45 | ||||
-rw-r--r-- | src/routines/common.cc | 65 |
3 files changed, 74 insertions, 50 deletions
diff --git a/src/clblast.cc b/src/clblast.cc index 2d6776d0..d0f0c937 100644 --- a/src/clblast.cc +++ b/src/clblast.cc @@ -29,10 +29,10 @@ #include "internal/routines/level1/xdotc.h" #include "internal/routines/level1/xnrm2.h" #include "internal/routines/level1/xasum.h" -#include "internal/routines/level1/xsum.h" // non-BLAS function +#include "internal/routines/level1/xsum.h" // non-BLAS routine #include "internal/routines/level1/xamax.h" -#include "internal/routines/level1/xmax.h" // non-BLAS function -#include "internal/routines/level1/xmin.h" // non-BLAS function +#include "internal/routines/level1/xmax.h" // non-BLAS routine +#include "internal/routines/level1/xmin.h" // non-BLAS routine // BLAS level-2 includes #include "internal/routines/level2/xgemv.h" @@ -68,7 +68,7 @@ #include "internal/routines/level3/xher2k.h" #include "internal/routines/level3/xtrmm.h" -// Extra includes (level-x) +// Level-x includes (non-BLAS) #include "internal/routines/levelx/xomatcopy.h" namespace clblast { @@ -2123,6 +2123,7 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose, StatusCode ClearCache() { return CacheClearAll(); } // Fills the cache with all binaries for a specific device +// TODO: Add half-precision FP16 set-up calls StatusCode FillCache(const cl_device_id device) { try { @@ -2171,7 +2172,7 @@ StatusCode FillCache(const cl_device_id device) { Xsyr2<float>(queue, nullptr).SetUp(); Xsyr2<double>(queue, nullptr).SetUp(); Xspr2<float>(queue, nullptr).SetUp(); Xspr2<double>(queue, nullptr).SetUp(); - // Runs all the level 1 set-up functions + // Runs all the level 3 set-up functions Xgemm<float>(queue, nullptr).SetUp(); Xgemm<double>(queue, nullptr).SetUp(); Xgemm<float2>(queue, nullptr).SetUp(); Xgemm<double2>(queue, nullptr).SetUp(); Xsymm<float>(queue, nullptr).SetUp(); Xsymm<double>(queue, nullptr).SetUp(); Xsymm<float2>(queue, nullptr).SetUp(); Xsymm<double2>(queue, nullptr).SetUp(); Xhemm<float2>(queue, nullptr).SetUp(); Xhemm<double2>(queue, nullptr).SetUp(); @@ -2181,6 +2182,9 @@ StatusCode FillCache(const cl_device_id device) { Xher2k<float2,float>(queue, nullptr).SetUp(); Xher2k<double2,double>(queue, nullptr).SetUp(); Xtrmm<float>(queue, nullptr).SetUp(); Xtrmm<double>(queue, nullptr).SetUp(); Xtrmm<float2>(queue, nullptr).SetUp(); Xtrmm<double2>(queue, nullptr).SetUp(); + // Runs all the level 3 set-up functions + Xomatcopy<float>(queue, nullptr).SetUp(); Xomatcopy<double>(queue, nullptr).SetUp(); Xomatcopy<float2>(queue, nullptr).SetUp(); Xomatcopy<double2>(queue, nullptr).SetUp(); + } catch (...) { return StatusCode::kBuildProgramFailure; } return StatusCode::kSuccess; } diff --git a/src/routine.cc b/src/routine.cc index 9b1640b5..11633ede 100644 --- a/src/routine.cc +++ b/src/routine.cc @@ -128,49 +128,4 @@ StatusCode Routine::SetUp() { } // ================================================================================================= - -// Enqueues a kernel, waits for completion, and checks for errors -StatusCode RunKernel(Kernel &kernel, Queue queue, const Device device, - std::vector<size_t> global, const std::vector<size_t> &local, - EventPointer event, std::vector<Event>& waitForEvents) { - - // Tests for validity of the local thread sizes - if (local.size() > device.MaxWorkItemDimensions()) { - return StatusCode::kInvalidLocalNumDimensions; - } - const auto max_work_item_sizes = device.MaxWorkItemSizes(); - for (auto i=size_t{0}; i<local.size(); ++i) { - if (local[i] > max_work_item_sizes[i]) { return StatusCode::kInvalidLocalThreadsDim; } - } - auto local_size = size_t{1}; - for (auto &item: local) { local_size *= item; } - if (local_size > device.MaxWorkGroupSize()) { return StatusCode::kInvalidLocalThreadsTotal; } - - // Make sure the global thread sizes are at least equal to the local sizes - for (auto i=size_t{0}; i<global.size(); ++i) { - if (global[i] < local[i]) { global[i] = local[i]; } - } - - // Tests for local memory usage - const auto local_mem_usage = kernel.LocalMemUsage(device); - if (!device.IsLocalMemoryValid(local_mem_usage)) { return StatusCode::kInvalidLocalMemUsage; } - - // Launches the kernel (and checks for launch errors) - try { - kernel.Launch(queue, global, local, event, waitForEvents); - } catch (...) { return StatusCode::kKernelLaunchError; } - - // No errors, normal termination of this function - return StatusCode::kSuccess; -} - -// As above, but without an event waiting list -StatusCode RunKernel(Kernel &kernel, Queue queue, const Device device, - std::vector<size_t> global, const std::vector<size_t> &local, - EventPointer event) { - auto emptyWaitingList = std::vector<Event>(); - return RunKernel(kernel, queue, device, global, local, event, emptyWaitingList); -} - -// ================================================================================================= } // namespace clblast diff --git a/src/routines/common.cc b/src/routines/common.cc new file mode 100644 index 00000000..561a1bd8 --- /dev/null +++ b/src/routines/common.cc @@ -0,0 +1,65 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren <www.cedricnugteren.nl> +// +// This file implements the common routine functions (see the header for more information). +// +// ================================================================================================= + +#include <vector> + +#include "internal/routines/common.h" + +namespace clblast { +// ================================================================================================= + +// Enqueues a kernel, waits for completion, and checks for errors +StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device, + std::vector<size_t> global, const std::vector<size_t> &local, + EventPointer event, std::vector<Event>& waitForEvents) { + + // Tests for validity of the local thread sizes + if (local.size() > device.MaxWorkItemDimensions()) { + return StatusCode::kInvalidLocalNumDimensions; + } + const auto max_work_item_sizes = device.MaxWorkItemSizes(); + for (auto i=size_t{0}; i<local.size(); ++i) { + if (local[i] > max_work_item_sizes[i]) { return StatusCode::kInvalidLocalThreadsDim; } + } + auto local_size = size_t{1}; + for (auto &item: local) { local_size *= item; } + if (local_size > device.MaxWorkGroupSize()) { return StatusCode::kInvalidLocalThreadsTotal; } + + // Make sure the global thread sizes are at least equal to the local sizes + for (auto i=size_t{0}; i<global.size(); ++i) { + if (global[i] < local[i]) { global[i] = local[i]; } + } + + // Tests for local memory usage + const auto local_mem_usage = kernel.LocalMemUsage(device); + if (!device.IsLocalMemoryValid(local_mem_usage)) { return StatusCode::kInvalidLocalMemUsage; } + + // Launches the kernel (and checks for launch errors) + try { + kernel.Launch(queue, global, local, event, waitForEvents); + } catch (...) { return StatusCode::kKernelLaunchError; } + + // No errors, normal termination of this function + return StatusCode::kSuccess; +} + +// As above, but without an event waiting list +StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device, + std::vector<size_t> global, const std::vector<size_t> &local, + EventPointer event) { + auto emptyWaitingList = std::vector<Event>(); + return RunKernel(kernel, queue, device, global, local, event, emptyWaitingList); +} + +// ================================================================================================= +} // namespace clblast |