diff options
Diffstat (limited to 'src/routines')
-rw-r--r-- | src/routines/common.hpp | 9 | ||||
-rw-r--r-- | src/routines/level2/xtrsv.cpp | 13 | ||||
-rw-r--r-- | src/routines/level2/xtrsv.hpp | 4 | ||||
-rw-r--r-- | src/routines/level3/xgemm.cpp | 1 | ||||
-rw-r--r-- | src/routines/level3/xherk.cpp | 1 | ||||
-rw-r--r-- | src/routines/level3/xsyrk.cpp | 1 | ||||
-rw-r--r-- | src/routines/level3/xtrsm.cpp | 2 | ||||
-rw-r--r-- | src/routines/level3/xtrsm.hpp | 1 | ||||
-rw-r--r-- | src/routines/levelx/xgemmbatched.cpp | 1 | ||||
-rw-r--r-- | src/routines/levelx/xgemmstridedbatched.cpp | 1 |
10 files changed, 23 insertions, 11 deletions
diff --git a/src/routines/common.hpp b/src/routines/common.hpp index c30a2e0e..c6db0152 100644 --- a/src/routines/common.hpp +++ b/src/routines/common.hpp @@ -76,6 +76,7 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device, // Determines the right kernel auto kernel_name = std::string{}; + auto pad_kernel = false; if (do_transpose) { if (use_fast_kernel && IsMultiple(src_ld, db["TRA_WPT"]) && @@ -85,7 +86,8 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device, } else { use_fast_kernel = false; - kernel_name = (do_pad) ? "TransposePadMatrix" : "TransposeMatrix"; + pad_kernel = (do_pad || do_conjugate); + kernel_name = (pad_kernel) ? "TransposePadMatrix" : "TransposeMatrix"; } } else { @@ -97,7 +99,8 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device, } else { use_fast_kernel = false; - kernel_name = (do_pad) ? "CopyPadMatrix" : "CopyMatrix"; + pad_kernel = do_pad; + kernel_name = (pad_kernel) ? "CopyPadMatrix" : "CopyMatrix"; } } @@ -123,7 +126,7 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device, kernel.SetArgument(8, static_cast<int>(dest_offset)); kernel.SetArgument(9, dest()); kernel.SetArgument(10, GetRealArg(alpha)); - if (do_pad) { + if (pad_kernel) { kernel.SetArgument(11, static_cast<int>(do_conjugate)); } else { diff --git a/src/routines/level2/xtrsv.cpp b/src/routines/level2/xtrsv.cpp index 76401753..2a5a5664 100644 --- a/src/routines/level2/xtrsv.cpp +++ b/src/routines/level2/xtrsv.cpp @@ -33,7 +33,8 @@ void Xtrsv<T>::Substitution(const Layout layout, const Triangle triangle, const size_t n, const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_inc, - const Buffer<T> &x_buffer, const size_t x_offset, const size_t x_inc) { + const Buffer<T> &x_buffer, const size_t x_offset, const size_t x_inc, + EventPointer event) { if (n > db_["TRSV_BLOCK_SIZE"]) { throw BLASError(StatusCode::kUnexpectedError); }; @@ -69,9 +70,7 @@ void Xtrsv<T>::Substitution(const Layout layout, const Triangle triangle, // Launches the kernel const auto local = std::vector<size_t>{db_["TRSV_BLOCK_SIZE"]}; const auto global = std::vector<size_t>{Ceil(n, db_["TRSV_BLOCK_SIZE"])}; - auto event = Event(); - RunKernel(kernel, queue_, device_, global, local, event.pointer()); - event.WaitForCompletion(); + RunKernel(kernel, queue_, device_, global, local, event); } // ================================================================================================= @@ -146,14 +145,16 @@ void Xtrsv<T>::DoTrsv(const Layout layout, const Triangle triangle, } // Runs the triangular substitution for the block size + auto sub_event = Event(); Substitution(layout, triangle, a_transpose, diagonal, block_size, a_buffer, a_offset + col + col*a_ld, a_ld, b_buffer, b_offset + col*b_inc, b_inc, - x_buffer, x_offset + col*x_inc, x_inc); + x_buffer, x_offset + col*x_inc, x_inc, sub_event.pointer()); + sub_event.WaitForCompletion(); } // Retrieves the results - x_buffer.CopyTo(queue_, x_size, b_buffer); + x_buffer.CopyToAsync(queue_, x_size, b_buffer, event_); } // ================================================================================================= diff --git a/src/routines/level2/xtrsv.hpp b/src/routines/level2/xtrsv.hpp index 67e626a1..8a900a35 100644 --- a/src/routines/level2/xtrsv.hpp +++ b/src/routines/level2/xtrsv.hpp @@ -32,6 +32,7 @@ class Xtrsv: public Xgemv<T> { using Xgemv<T>::device_; using Xgemv<T>::db_; using Xgemv<T>::program_; + using Xgemv<T>::event_; using Xgemv<T>::DoGemv; // Constructor @@ -50,7 +51,8 @@ class Xtrsv: public Xgemv<T> { const size_t n, const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_inc, - const Buffer<T> &x_buffer, const size_t offset_x, const size_t x_inc); + const Buffer<T> &x_buffer, const size_t offset_x, const size_t x_inc, + EventPointer event); }; // ================================================================================================= diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp index fd5a20db..cb24460a 100644 --- a/src/routines/level3/xgemm.cpp +++ b/src/routines/level3/xgemm.cpp @@ -40,6 +40,7 @@ Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name): , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part1.opencl" #include "../../kernels/level3/xgemm_part2.opencl" + , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part3.opencl" #include "../../kernels/level3/xgemm_part4.opencl" }) { diff --git a/src/routines/level3/xherk.cpp b/src/routines/level3/xherk.cpp index 6912d3a9..2e6f30ec 100644 --- a/src/routines/level3/xherk.cpp +++ b/src/routines/level3/xherk.cpp @@ -32,6 +32,7 @@ Xherk<T,U>::Xherk(Queue &queue, EventPointer event, const std::string &name): , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part1.opencl" #include "../../kernels/level3/xgemm_part2.opencl" + , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part3.opencl" #include "../../kernels/level3/xgemm_part4.opencl" }) { diff --git a/src/routines/level3/xsyrk.cpp b/src/routines/level3/xsyrk.cpp index 6bb2a24f..5ffdc028 100644 --- a/src/routines/level3/xsyrk.cpp +++ b/src/routines/level3/xsyrk.cpp @@ -32,6 +32,7 @@ Xsyrk<T>::Xsyrk(Queue &queue, EventPointer event, const std::string &name): , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part1.opencl" #include "../../kernels/level3/xgemm_part2.opencl" + , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part3.opencl" #include "../../kernels/level3/xgemm_part4.opencl" }) { diff --git a/src/routines/level3/xtrsm.cpp b/src/routines/level3/xtrsm.cpp index 905660ff..fe5d1e14 100644 --- a/src/routines/level3/xtrsm.cpp +++ b/src/routines/level3/xtrsm.cpp @@ -246,7 +246,7 @@ void Xtrsm<T>::TrsmColMajor(const Side side, const Triangle triangle, } // Retrieves the results - x_buffer.CopyTo(queue_, b_size, b_buffer); + x_buffer.CopyToAsync(queue_, b_size, b_buffer, event_); } // ================================================================================================= diff --git a/src/routines/level3/xtrsm.hpp b/src/routines/level3/xtrsm.hpp index 5b42398e..871d7253 100644 --- a/src/routines/level3/xtrsm.hpp +++ b/src/routines/level3/xtrsm.hpp @@ -31,6 +31,7 @@ class Xtrsm: public Xgemm<T> { using Xgemm<T>::device_; using Xgemm<T>::db_; using Xgemm<T>::program_; + using Xgemm<T>::event_; using Xgemm<T>::DoGemm; // Constructor diff --git a/src/routines/levelx/xgemmbatched.cpp b/src/routines/levelx/xgemmbatched.cpp index 2bbc5007..b12b8734 100644 --- a/src/routines/levelx/xgemmbatched.cpp +++ b/src/routines/levelx/xgemmbatched.cpp @@ -38,6 +38,7 @@ XgemmBatched<T>::XgemmBatched(Queue &queue, EventPointer event, const std::strin , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part1.opencl" #include "../../kernels/level3/xgemm_part2.opencl" + , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part3.opencl" #include "../../kernels/level3/xgemm_part4.opencl" , // separated in multiple parts to prevent C1091 in MSVC 2013 diff --git a/src/routines/levelx/xgemmstridedbatched.cpp b/src/routines/levelx/xgemmstridedbatched.cpp index 30c161cc..d9e3ebba 100644 --- a/src/routines/levelx/xgemmstridedbatched.cpp +++ b/src/routines/levelx/xgemmstridedbatched.cpp @@ -37,6 +37,7 @@ XgemmStridedBatched<T>::XgemmStridedBatched(Queue &queue, EventPointer event, co , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part1.opencl" #include "../../kernels/level3/xgemm_part2.opencl" + , // separated in multiple parts to prevent C1091 in MSVC 2013 #include "../../kernels/level3/xgemm_part3.opencl" #include "../../kernels/level3/xgemm_part4.opencl" , // separated in multiple parts to prevent C1091 in MSVC 2013 |