summaryrefslogtreecommitdiff
path: root/src/routines/levelx/xconvgemm.cpp
blob: 8cb8093ca4a7a6320d52b10da49dcfd7b500cbb2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
//   Cedric Nugteren <www.cedricnugteren.nl>
//
// This file implements the Xconvgemm class (see the header for information about the class).
//
// =================================================================================================

#include <string>
#include <vector>

#include "routines/levelx/xconvgemm.hpp"
#include "routines/levelx/xim2col.hpp"

namespace clblast {
// =================================================================================================

// Constructor: forwards to the base-class Routine constructor.
//
// The routine is registered with the routine list {"XgemmDirect"}. Presumably this makes it re-use
// the XgemmDirect tuning parameters (DoConvgemm reads WGD, MDIMCD and NDIMCD from db_); confirm
// against Routine. The precision is derived from T via PrecisionValue<T>(). The empty '{}' argument
// is passed through unchanged; see the Routine constructor for its meaning.
//
// The OpenCL program source is given as three comma-separated pieces:
// - the common level-3 code;
// - the three XgemmDirect parts;
// - the Xconvgemm kernel.
// The #include'd .opencl files presumably expand to string literals; verify in the kernel files.
// The source is split because a single huge literal hits MSVC 2013's C1091 string-length limit.
template <typename T>
Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &name):
    Routine(queue, event, name, {"XgemmDirect"},
        PrecisionValue<T>(), {}, {
            #include "../../kernels/level3/level3.opencl"
            , // separated in multiple parts to prevent C1091 in MSVC 2013
            #include "../../kernels/level3/xgemm_direct_part1.opencl"
            #include "../../kernels/level3/xgemm_direct_part2.opencl"
            #include "../../kernels/level3/xgemm_direct_part3.opencl"
            , // separated in multiple parts to prevent C1091 in MSVC 2013
            #include "../../kernels/level3/xconvgemm.opencl"
        }) {
}

// =================================================================================================

// Computes a batched 2D convolution as im2col followed by a strided-batched GEMM:
//      result = GEMM(im2col(image), kernel)
//
// The 'batch_count' input images are stored back-to-back in 'im_buffer', starting at 'im_offset',
// with channels * height * width elements each. For every batch, Xim2col unrolls the image into
// its own num_patches x patch_size slice of a temporary col buffer. A single Xconvgemm kernel
// launch then multiplies each col slice with the shared patch_size x num_kernels kernel matrix
// (from 'kernel_buffer' at 'kernel_offset'). It writes num_patches x num_kernels values per batch
// into 'result_buffer', starting at 'result_offset'.
//
// Throws:
// - BLASError(kInvalidBatchCount) if batch_count is zero;
// - BLASError(kInvalidDimension) if any of channels, height, width, num_kernels, kernel_h,
//   kernel_w, stride_h or stride_w is zero.
// Errors from the buffer checks (TestMatrixA/B/C), im2col, or OpenCL calls propagate unchanged.
template <typename T>
void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const size_t width,
                              const size_t kernel_h, const size_t kernel_w, const size_t pad_h,
                              const size_t pad_w, const size_t stride_h, const size_t stride_w,
                              const size_t dilation_h, const size_t dilation_w,
                              const size_t num_kernels, const size_t batch_count,
                              const Buffer<T> &im_buffer, const size_t im_offset,
                              const Buffer<T> &kernel_buffer, const size_t kernel_offset,
                              const Buffer<T> &result_buffer, const size_t result_offset) {

  // Tests for a valid batch count
  if (batch_count == 0) {
    throw BLASError(StatusCode::kInvalidBatchCount);
  }

  // Makes sure all dimensions are larger than zero
  if ((channels == 0) || (height == 0) || (width == 0) || (num_kernels == 0)) {
    throw BLASError(StatusCode::kInvalidDimension);
  }

  // Makes sure the kernel size and the strides are non-zero. A zero kernel size would underflow
  // 'kernel - 1' below and produce an empty col buffer; a zero stride would divide by zero.
  if ((kernel_h == 0) || (kernel_w == 0) || (stride_h == 0) || (stride_w == 0)) {
    throw BLASError(StatusCode::kInvalidDimension);
  }

  // Sets the output height and width (falls back to 1 if the dilated kernel exceeds the padded
  // image, matching the original behaviour)
  const auto size_h = height + 2 * pad_h;
  const auto padding_h = dilation_h * (kernel_h - 1) + 1;
  const auto output_h = (size_h >= padding_h) ? (size_h - padding_h) / stride_h + 1 : 1;
  const auto size_w = width + 2 * pad_w;
  const auto padding_w = dilation_w * (kernel_w - 1) + 1;
  const auto output_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1;

  // Sets other useful variables
  const auto patch_size = kernel_h * kernel_w * channels;
  const auto num_patches = output_h * output_w;

  // Per-batch strides (in elements) of the input image, the col buffer, and the result
  const auto im_stride = channels * height * width;
  const auto col_stride = patch_size * num_patches;
  const auto result_stride = num_kernels * num_patches;

  // Temporary col matrix holding the unrolled patches of all batches back-to-back
  const auto col_size = col_stride * batch_count;
  auto col_buffer = Buffer<T>(context_, col_size);

  // im2col: unrolls each batch's image into its own slice of the col buffer
  for (auto batch_id = size_t{0}; batch_id < batch_count; ++batch_id) {
    const auto im_batch_offset = im_offset + batch_id * im_stride;
    const auto col_batch_offset = batch_id * col_stride;
    auto im2col_event = Event();
    auto im2col = Xim2col<T>(queue_, im2col_event.pointer());
    im2col.DoIm2col(channels, height, width, kernel_h, kernel_w,
                    pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
                    im_buffer, im_batch_offset,
                    col_buffer, col_batch_offset);
    im2col_event.WaitForCompletion();
  }

  // Strided batched GEMM: C (result) = alpha (1) * A (col) * B (kernel) + beta (0) * C (result)

  // Tests the matrices for validity
  TestMatrixB(patch_size, num_kernels, kernel_buffer, kernel_offset, patch_size);
  for (auto batch = size_t{0}; batch < batch_count; ++batch) {
    TestMatrixA(num_patches, patch_size, col_buffer, col_stride * batch, num_patches);
    TestMatrixC(num_patches, num_kernels, result_buffer, result_offset + result_stride * batch,
                num_patches);
  }

  // Retrieves the Xconvgemm kernel from the compiled program
  auto kernel = Kernel(program_, "Xconvgemm");

  // Sets the kernel arguments
  kernel.SetArgument(0, static_cast<int>(num_patches));
  kernel.SetArgument(1, static_cast<int>(num_kernels));
  kernel.SetArgument(2, static_cast<int>(patch_size));
  kernel.SetArgument(3, col_buffer());
  kernel.SetArgument(4, static_cast<int>(0));
  kernel.SetArgument(5, static_cast<int>(col_stride));
  kernel.SetArgument(6, kernel_buffer());
  kernel.SetArgument(7, static_cast<int>(kernel_offset));
  kernel.SetArgument(8, result_buffer());
  kernel.SetArgument(9, static_cast<int>(result_offset));
  kernel.SetArgument(10, static_cast<int>(result_stride));

  // Computes the global and local thread sizes: one work-group tile of WGD x WGD outputs per
  // group, and one slice of the third dimension per batch
  const auto m_ceiled = Ceil(num_patches, db_["WGD"]);
  const auto n_ceiled = Ceil(num_kernels, db_["WGD"]);
  const auto global = std::vector<size_t>{
      (m_ceiled * db_["MDIMCD"]) / db_["WGD"],
      (n_ceiled * db_["NDIMCD"]) / db_["WGD"],
      batch_count
  };
  const auto local = std::vector<size_t>{db_["MDIMCD"], db_["NDIMCD"], 1};

  // Launches the kernel
  RunKernel(kernel, queue_, device_, global, local, event_);
}

// =================================================================================================

// Explicit instantiations of the routine, so that the member definitions in this file are emitted
// for every supported element type: the real types float, double and half, followed by the
// two-component float2 and double2 types (presumably CLBlast's complex types).
template class Xconvgemm<float>;
template class Xconvgemm<double>;
template class Xconvgemm<half>;
template class Xconvgemm<float2>;
template class Xconvgemm<double2>;

// =================================================================================================
} // namespace clblast