diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2017-02-05 14:36:31 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2017-02-05 14:36:31 +0100 |
commit | e7cbb5915aef16f3a64566292459eaede5a600e5 (patch) | |
tree | 8a18e7018b1922d9d445eede6af7d5140f33dc71 | |
parent | c209dd7af90d604c8210cc5680b6c7a50b2b995f (diff) |
Fixed complex version of the TRSV kernel
-rw-r--r-- | src/kernels/common.opencl | 11 | ||||
-rw-r--r-- | src/kernels/level2/xtrsv.opencl | 16 | ||||
-rw-r--r-- | test/correctness/tester.cpp | 4 |
3 files changed, 25 insertions, 6 deletions
diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl index 8e59a5fe..c052e94f 100644 --- a/src/kernels/common.opencl +++ b/src/kernels/common.opencl @@ -214,13 +214,20 @@ R"( #define MultiplySubtract(c, a, b) c -= a * b #endif -// The scalar division function +// The scalar division function: real-value only #if PRECISION == 3232 || PRECISION == 6464 - #define DivideReal(c, a, b) c.x = a.x / b.x; c.y = a.x + #define DivideReal(c, a, b) c.x = a.x / b.x; c.y = a.y #else #define DivideReal(c, a, b) c = a / b #endif +// The scalar division function: full division +#if PRECISION == 3232 || PRECISION == 6464 + #define DivideFull(c, a, b) singlereal num_x = (a.x * b.x) + (a.y * b.y); singlereal num_y = (a.y * b.x) - (a.x * b.y); singlereal denom = (b.x * b.x) + (b.y * b.y); c.x = num_x / denom; c.y = num_y / denom +#else + #define DivideFull(c, a, b) c = a / b +#endif + // The scalar AXPBY function #if PRECISION == 3232 || PRECISION == 6464 #define AXPBY(e, a, b, c, d) e.x = MulReal(a,b) + MulReal(c,d); e.y = MulImag(a,b) + MulImag(c,d) diff --git a/src/kernels/level2/xtrsv.opencl b/src/kernels/level2/xtrsv.opencl index fd5de200..ebea77a3 100644 --- a/src/kernels/level2/xtrsv.opencl +++ b/src/kernels/level2/xtrsv.opencl @@ -55,7 +55,6 @@ void trsv_forward(int n, if (is_transposed == 0) { for (int i = 0; i < n; ++i) { alm[i][tid] = A[i + tid*a_ld + a_offset]; - if (do_conjugate) { COMPLEX_CONJUGATE(alm[i][tid]); } } } else { @@ -63,6 +62,11 @@ void trsv_forward(int n, alm[i][tid] = A[tid + i*a_ld + a_offset]; } } + if (do_conjugate) { + for (int i = 0; i < n; ++i) { + COMPLEX_CONJUGATE(alm[i][tid]); + } + } } barrier(CLK_LOCAL_MEM_FENCE); @@ -72,7 +76,7 @@ void trsv_forward(int n, for (int j = 0; j < i; ++j) { MultiplySubtract(xlm[i], alm[i][j], xlm[j]); } - if (is_unit_diagonal == 0) { DivideReal(xlm[i], xlm[i], alm[i][i]); } + if (is_unit_diagonal == 0) { DivideFull(xlm[i], xlm[i], alm[i][i]); } } } barrier(CLK_LOCAL_MEM_FENCE); @@ -99,7 +103,6 @@ void trsv_backward(int n, if (is_transposed == 0) { for (int i = 0; i < n; ++i) { alm[i][tid] = A[i + tid*a_ld + a_offset]; - if (do_conjugate) { COMPLEX_CONJUGATE(alm[i][tid]); } } } else { @@ -107,6 +110,11 @@ void trsv_backward(int n, alm[i][tid] = A[tid + i*a_ld + a_offset]; } } + if (do_conjugate) { + for (int i = 0; i < n; ++i) { + COMPLEX_CONJUGATE(alm[i][tid]); + } + } } barrier(CLK_LOCAL_MEM_FENCE); @@ -116,7 +124,7 @@ void trsv_backward(int n, for (int j = i + 1; j < n; ++j) { MultiplySubtract(xlm[i], alm[i][j], xlm[j]); } - if (is_unit_diagonal == 0) { DivideReal(xlm[i], xlm[i], alm[i][i]); } + if (is_unit_diagonal == 0) { DivideFull(xlm[i], xlm[i], alm[i][i]); } } } barrier(CLK_LOCAL_MEM_FENCE); diff --git a/test/correctness/tester.cpp b/test/correctness/tester.cpp index c449b09d..dc0f842e 100644 --- a/test/correctness/tester.cpp +++ b/test/correctness/tester.cpp @@ -410,6 +410,10 @@ bool TestSimilarityNear(const T val1, const T val2, if (val1 == val2) { return true; } + // Handles cases with both results NaN + else if (std::isnan(val1) && std::isnan(val2)) { + return true; + } // The values are zero or very small: the relative error is less meaningful else if (val1 == 0 || val2 == 0 || difference < error_margin_absolute) { return (difference < error_margin_absolute); |