diff options
Diffstat (limited to 'src/kernels/level2/xtrsv.opencl')
-rw-r--r-- | src/kernels/level2/xtrsv.opencl | 86 |
1 files changed, 54 insertions, 32 deletions
diff --git a/src/kernels/level2/xtrsv.opencl b/src/kernels/level2/xtrsv.opencl index 01bd6ba5..fd5de200 100644 --- a/src/kernels/level2/xtrsv.opencl +++ b/src/kernels/level2/xtrsv.opencl @@ -41,67 +41,89 @@ void FillVector(const int n, const int inc, const int offset, __kernel __attribute__((reqd_work_group_size(TRSV_BLOCK_SIZE, 1, 1))) void trsv_forward(int n, - const __global real *A, const int a_offset, int lda, + const __global real *A, const int a_offset, int a_ld, __global real *b, const int b_offset, int b_inc, __global real *x, const int x_offset, int x_inc, - const int is_transposed, const int is_unit_diagonal) { - __local real sx[TRSV_BLOCK_SIZE]; + const int is_transposed, const int is_unit_diagonal, const int do_conjugate) { + __local real alm[TRSV_BLOCK_SIZE][TRSV_BLOCK_SIZE]; + __local real xlm[TRSV_BLOCK_SIZE]; const int tid = get_local_id(0); + + // Pre-loads the data into local memory if (tid < n) { - sx[tid] = b[tid*b_inc + b_offset]; + Subtract(xlm[tid], b[tid*b_inc + b_offset], x[tid*x_inc + x_offset]); + if (is_transposed == 0) { + for (int i = 0; i < n; ++i) { + alm[i][tid] = A[i + tid*a_ld + a_offset]; + if (do_conjugate) { COMPLEX_CONJUGATE(alm[i][tid]); } + } + } + else { + for (int i = 0; i < n; ++i) { + alm[i][tid] = A[tid + i*a_ld + a_offset]; + } + } } barrier(CLK_LOCAL_MEM_FENCE); - for (int i = 0; i < n; ++i) { - if (tid == 0) { - real sum = sx[i]; + + // Computes the result (single-threaded for now) + if (tid == 0) { + for (int i = 0; i < n; ++i) { for (int j = 0; j < i; ++j) { - real a_value; - if (is_transposed == 0) { a_value = A[i + j*lda + a_offset]; } - else { a_value = A[j + i*lda + a_offset]; } - sum -= a_value * sx[j]; + MultiplySubtract(xlm[i], alm[i][j], xlm[j]); } - sum -= x[i*x_inc + x_offset]; - if (is_unit_diagonal == 0) { sum /= A[i + i*lda + a_offset]; } - sx[i] = sum; + if (is_unit_diagonal == 0) { DivideReal(xlm[i], xlm[i], alm[i][i]); } } - barrier(CLK_LOCAL_MEM_FENCE); } barrier(CLK_LOCAL_MEM_FENCE); + + // Stores the results if (tid < n) { - x[tid*x_inc + x_offset] = sx[tid]; + x[tid*x_inc + x_offset] = xlm[tid]; } } __kernel __attribute__((reqd_work_group_size(TRSV_BLOCK_SIZE, 1, 1))) void trsv_backward(int n, - const __global real *A, const int a_offset, int lda, + const __global real *A, const int a_offset, int a_ld, __global real *b, const int b_offset, int b_inc, __global real *x, const int x_offset, int x_inc, - const int is_trans, const int is_unit_diagonal) { - __local real sx[TRSV_BLOCK_SIZE]; + const int is_transposed, const int is_unit_diagonal, const int do_conjugate) { + __local real alm[TRSV_BLOCK_SIZE][TRSV_BLOCK_SIZE]; + __local real xlm[TRSV_BLOCK_SIZE]; const int tid = get_local_id(0); + + // Pre-loads the data into local memory if (tid < n) { - sx[tid] = b[tid*b_inc + b_offset]; + Subtract(xlm[tid], b[tid*b_inc + b_offset], x[tid*x_inc + x_offset]); + if (is_transposed == 0) { + for (int i = 0; i < n; ++i) { + alm[i][tid] = A[i + tid*a_ld + a_offset]; + if (do_conjugate) { COMPLEX_CONJUGATE(alm[i][tid]); } + } + } + else { + for (int i = 0; i < n; ++i) { + alm[i][tid] = A[tid + i*a_ld + a_offset]; + } + } } barrier(CLK_LOCAL_MEM_FENCE); - for (int i = n - 1; i >= 0; --i) { - if (tid == 0) { - real sum = sx[i]; + + // Computes the result (single-threaded for now) + if (tid == 0) { + for (int i = n - 1; i >= 0; --i) { for (int j = i + 1; j < n; ++j) { - real a_value; - if (is_trans == 0) { a_value = A[i + j*lda + a_offset]; } - else { a_value = A[j + i*lda + a_offset]; } - sum -= a_value * sx[j]; + MultiplySubtract(xlm[i], alm[i][j], xlm[j]); } - sum -= x[i*x_inc + x_offset]; - if (is_unit_diagonal == 0) { sum /= A[i + i*lda + a_offset]; } - sx[i] = sum; + if (is_unit_diagonal == 0) { DivideReal(xlm[i], xlm[i], alm[i][i]); } } - barrier(CLK_LOCAL_MEM_FENCE); } barrier(CLK_LOCAL_MEM_FENCE); + + // Stores the results if (tid < n) { - x[tid*x_inc + x_offset] = sx[tid]; + x[tid*x_inc + x_offset] = xlm[tid]; } } |