diff options
author | Tarmo Räntilä <trantila@iki.fi> | 2019-12-09 22:17:24 +0200 |
---|---|---|
committer | Tarmo Räntilä <trantila@iki.fi> | 2019-12-09 22:17:24 +0200 |
commit | 21b66ca76140be9ac30811e7648abe3837e19177 (patch) | |
tree | a75de219a5393a848fa083f59f7add5fcc22aa42 /src/utilities | |
parent | bf50c4e53e1815d4b376f35b5be5c747cd857414 (diff) |
Reduce TestMatrix calls for xgemmstridedbatched.
Replace the looped test by a single one with the offset of the last batch.
Diffstat (limited to 'src/utilities')
-rw-r--r-- | src/utilities/buffer_test.hpp | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/src/utilities/buffer_test.hpp b/src/utilities/buffer_test.hpp index 9cecce97..4a2a2c95 100644 --- a/src/utilities/buffer_test.hpp +++ b/src/utilities/buffer_test.hpp @@ -134,6 +134,35 @@ void TestBatchedMatrixC(const size_t one, const size_t two, const Buffer<T>& buf } // ================================================================================================= + +// Tests matrix 'A' for validity in a strided batched setting +template <typename T> +void TestStridedBatchedMatrixA(const size_t one, const size_t two, const Buffer<T>& buffer, + const size_t offset, const size_t stride, const size_t batch_count, + const size_t ld, const bool test_lead_dim = true) { + const auto last_batch_offset = (batch_count - 1) * stride; + TestMatrixA(one, two, buffer, offset + last_batch_offset, ld, test_lead_dim); +} + +// Tests matrix 'B' for validity in a strided batched setting +template <typename T> +void TestStridedBatchedMatrixB(const size_t one, const size_t two, const Buffer<T>& buffer, + const size_t offset, const size_t stride, const size_t batch_count, + const size_t ld, const bool test_lead_dim = true) { + const auto last_batch_offset = (batch_count - 1) * stride; + TestMatrixB(one, two, buffer, offset + last_batch_offset, ld, test_lead_dim); +} + +// Tests matrix 'C' for validity in a strided batched setting +template <typename T> +void TestStridedBatchedMatrixC(const size_t one, const size_t two, const Buffer<T>& buffer, + const size_t offset, const size_t stride, const size_t batch_count, + const size_t ld) { + const auto last_batch_offset = (batch_count - 1) * stride; + TestMatrixC(one, two, buffer, offset + last_batch_offset, ld); +} + +// ================================================================================================= } // namespace clblast // CLBLAST_BUFFER_TEST_H_ |