summaryrefslogtreecommitdiff
path: root/src/routines/common.cpp
blob: 21e16954bcaa16427bf7efb92b0a304d5bd7f919 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
//   Cedric Nugteren <www.cedricnugteren.nl>
//
// This file implements the common routine functions (see the header for more information).
//
// =================================================================================================

#include <vector>
#include <chrono>

#include "routines/common.hpp"

namespace clblast {
// =================================================================================================

// Validates the requested thread configuration against the device limits and enqueues the kernel.
// The function only blocks until the kernel has completed when compiled with VERBOSE defined (it
// then also prints the kernel name and the elapsed time); otherwise it returns once the launch has
// been issued.
//
// Arguments:
//   kernel         the kernel to launch; also queried for its local memory usage
//   queue          the queue on which the kernel is launched
//   device         queried for the work-item/work-group limits and the local memory validity
//   global         global thread sizes; taken by value because they are adjusted before the launch
//   local          local thread sizes per dimension
//   event          passed on unchanged to Kernel::Launch
//   waitForEvents  passed on unchanged to Kernel::Launch
//
// Returns kSuccess when the launch was issued, or otherwise one of: kInvalidLocalNumDimensions,
// kInvalidLocalThreadsDim, kInvalidLocalThreadsTotal, kInvalidLocalMemUsage, kKernelLaunchError.
StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device,
                     std::vector<size_t> global, const std::vector<size_t> &local,
                     EventPointer event, const std::vector<Event> &waitForEvents) {

  // Tests for validity of the local thread sizes: the number of dimensions, the size of each
  // individual dimension, and the total number of threads in a work-group
  if (local.size() > device.MaxWorkItemDimensions()) {
    return StatusCode::kInvalidLocalNumDimensions;
  }
  // NOTE(review): assumes MaxWorkItemSizes() has at least MaxWorkItemDimensions() entries - confirm
  const auto max_work_item_sizes = device.MaxWorkItemSizes();
  for (auto i=size_t{0}; i<local.size(); ++i) {
    if (local[i] > max_work_item_sizes[i]) { return StatusCode::kInvalidLocalThreadsDim; }
  }
  auto local_size = size_t{1};
  for (const auto &item: local) { local_size *= item; }
  if (local_size > device.MaxWorkGroupSize()) { return StatusCode::kInvalidLocalThreadsTotal; }

  // Make sure the global thread sizes are at least equal to the local sizes. Only the dimensions
  // present in both vectors are compared: 'global' can hold more entries than 'local' (for example
  // when 'local' is empty), and indexing 'local' with those extra indices would read out of bounds.
  for (auto i=size_t{0}; i<global.size() && i<local.size(); ++i) {
    if (global[i] < local[i]) { global[i] = local[i]; }
  }

  // Tests for local memory usage
  const auto local_mem_usage = kernel.LocalMemUsage(device);
  if (!device.IsLocalMemoryValid(local_mem_usage)) { return StatusCode::kInvalidLocalMemUsage; }

  // Prints the name of the kernel to launch in case of debugging in verbose mode. The queue is
  // drained first such that previously enqueued work is not included in the measured time.
  #ifdef VERBOSE
    queue.Finish();
    printf("[DEBUG] Running kernel '%s'\n", kernel.GetFunctionName().c_str());
    const auto start_time = std::chrono::steady_clock::now();
  #endif

  // Launches the kernel (and checks for launch errors). Any exception thrown by the launch is
  // deliberately translated into a status code, since this function reports errors by return value.
  try {
    kernel.Launch(queue, global, local, event, waitForEvents);
  } catch (...) { return StatusCode::kKernelLaunchError; }

  // Prints the elapsed execution time in case of debugging in verbose mode
  #ifdef VERBOSE
    queue.Finish();
    const auto elapsed_time = std::chrono::steady_clock::now() - start_time;
    const auto timing = std::chrono::duration<double,std::milli>(elapsed_time).count();
    printf("[DEBUG] Completed kernel in %.2lf ms\n", timing);
  #endif

  // No errors, normal termination of this function
  return StatusCode::kSuccess;
}

// Convenience overload for launches without an event waiting list: forwards every argument to the
// seven-argument RunKernel, supplying an empty list of events to wait for
StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device,
                     std::vector<size_t> global, const std::vector<size_t> &local,
                     EventPointer event) {
  const std::vector<Event> no_dependencies{};
  return RunKernel(kernel, queue, device, global, local, event, no_dependencies);
}

// =================================================================================================
} // namespace clblast