summaryrefslogtreecommitdiff
path: root/test/correctness/testabc.h
blob: 41100db72afeaaf9dac795aa3287bda35377a134 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under the MIT license. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
//   Cedric Nugteren <www.cedricnugteren.nl>
//
// This file tests any mat-mat-mat (A,B,C) routine. It contains two types of tests: one testing
// all sorts of input combinations, and one deliberately testing with invalid values.
//
// =================================================================================================

#ifndef CLBLAST_TEST_CORRECTNESS_TESTABC_H_
#define CLBLAST_TEST_CORRECTNESS_TESTABC_H_

#include <functional>
#include <string>
#include <vector>

#include "correctness/tester.h"

namespace clblast {
// =================================================================================================

// Test harness for mat-mat-mat (A,B,C) routines. It offers the two kinds of tests described at the
// top of this file: a regular test over many input combinations and a test with invalid buffer
// sizes. Only the interface is declared here; the member functions are not defined in this header.
template <typename T>
class TestABC: public Tester<T> {
 public:

  // Uses several variables from the Tester class
  using Tester<T>::context_;
  using Tester<T>::queue_;
  using Tester<T>::kLayouts;
  using Tester<T>::kTransposes;

  // Uses several helper functions from the Tester class
  using Tester<T>::TestStart;
  using Tester<T>::TestEnd;
  using Tester<T>::TestSimilarity;
  using Tester<T>::TestErrorCount;
  using Tester<T>::TestErrorCodes;
  using Tester<T>::GetExampleScalars;
  using Tester<T>::GetOffsets;
  using Tester<T>::PrecisionSupported;

  // Test settings for the regular test. Append to this list in case more tests are required.
  // Note: these in-class initialisers run in declaration order, after the Tester<T> base has been
  // constructed. That is why kOffsets, kAlphaValues and kBetaValues can call the base-class helpers
  // GetOffsets() and GetExampleScalars(). Keep that in mind before reordering or moving them.
  const std::vector<size_t> kMatrixDims = { 7, 64 };
  const std::vector<size_t> kOffsets = GetOffsets();
  const std::vector<T> kAlphaValues = GetExampleScalars();
  const std::vector<T> kBetaValues = GetExampleScalars();

  // Test settings for the invalid test (a fixed buffer size used by TestInvalidBufferSizes)
  const size_t kBufferSize = 64;

  // Shorthand for a BLAS routine: takes the routine arguments, three device buffers (presumably
  // the A, B and C matrices, in that order — confirm against the implementation) and a command
  // queue, and returns a StatusCode.
  using Routine = std::function<StatusCode(const Arguments<T>&,
                                           const Buffer&, const Buffer&, const Buffer&,
                                           CommandQueue&)>;

  // Constructor, initializes the base class tester and input data.
  //   argc, argv:     command-line arguments (forwarded to the tester, presumably)
  //   silent:         flag passed on to the tester (exact effect defined in Tester — TODO confirm)
  //   name:           name of the routine under test
  //   options:        list of option names relevant to this routine
  //   clblast_lambda: the CLBlast implementation of the routine
  //   clblas_lambda:  the clBLAS implementation of the routine (presumably the reference result)
  TestABC(int argc, char *argv[], const bool silent,
          const std::string &name, const std::vector<std::string> &options,
          const Routine clblast_lambda, const Routine clblas_lambda);

  // The test functions. Both take the routine arguments (which they may modify, as they are passed
  // by non-const reference) and the name of the routine under test.
  // NOTE(review): the effect of `symmetric` is defined in the implementation; presumably it marks
  // routines whose A matrix is symmetric. Confirm against the definition.
  void TestRegular(Arguments<T> &args, const std::string &name, const bool symmetric);
  void TestInvalidBufferSizes(Arguments<T> &args, const std::string &name);

 private:

  // Source data to test with, one host-side vector per matrix (A, B and C)
  std::vector<T> a_source_;
  std::vector<T> b_source_;
  std::vector<T> c_source_;

  // The routines to test (CLBlast and clBLAS versions, as passed to the constructor)
  Routine clblast_lambda_;
  Routine clblas_lambda_;
};

// =================================================================================================
} // namespace clblast

// CLBLAST_TEST_CORRECTNESS_TESTABC_H_
#endif