summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- src/kernels/common.opencl 11
-rw-r--r-- src/kernels/level2/xtrsv.opencl 16
-rw-r--r-- test/correctness/tester.cpp 4
3 files changed, 25 insertions, 6 deletions
diff --git a/src/kernels/common.opencl b/src/kernels/common.opencl
index 8e59a5fe..c052e94f 100644
--- a/src/kernels/common.opencl
+++ b/src/kernels/common.opencl
@@ -214,13 +214,20 @@ R"(
#define MultiplySubtract(c, a, b) c -= a * b
#endif
-// The scalar division function
+// The scalar division function: real-value only
#if PRECISION == 3232 || PRECISION == 6464
- #define DivideReal(c, a, b) c.x = a.x / b.x; c.y = a.x
+ #define DivideReal(c, a, b) c.x = a.x / b.x; c.y = a.y
#else
#define DivideReal(c, a, b) c = a / b
#endif
+// The scalar division function: full division
+#if PRECISION == 3232 || PRECISION == 6464
+ #define DivideFull(c, a, b) singlereal num_x = (a.x * b.x) + (a.y * b.y); singlereal num_y = (a.y * b.x) - (a.x * b.y); singlereal denom = (b.x * b.x) + (b.y * b.y); c.x = num_x / denom; c.y = num_y / denom
+#else
+ #define DivideFull(c, a, b) c = a / b
+#endif
+
// The scalar AXPBY function
#if PRECISION == 3232 || PRECISION == 6464
#define AXPBY(e, a, b, c, d) e.x = MulReal(a,b) + MulReal(c,d); e.y = MulImag(a,b) + MulImag(c,d)
diff --git a/src/kernels/level2/xtrsv.opencl b/src/kernels/level2/xtrsv.opencl
index fd5de200..ebea77a3 100644
--- a/src/kernels/level2/xtrsv.opencl
+++ b/src/kernels/level2/xtrsv.opencl
@@ -55,7 +55,6 @@ void trsv_forward(int n,
if (is_transposed == 0) {
for (int i = 0; i < n; ++i) {
alm[i][tid] = A[i + tid*a_ld + a_offset];
- if (do_conjugate) { COMPLEX_CONJUGATE(alm[i][tid]); }
}
}
else {
@@ -63,6 +62,11 @@ void trsv_forward(int n,
alm[i][tid] = A[tid + i*a_ld + a_offset];
}
}
+ if (do_conjugate) {
+ for (int i = 0; i < n; ++i) {
+ COMPLEX_CONJUGATE(alm[i][tid]);
+ }
+ }
}
barrier(CLK_LOCAL_MEM_FENCE);
@@ -72,7 +76,7 @@ void trsv_forward(int n,
for (int j = 0; j < i; ++j) {
MultiplySubtract(xlm[i], alm[i][j], xlm[j]);
}
- if (is_unit_diagonal == 0) { DivideReal(xlm[i], xlm[i], alm[i][i]); }
+ if (is_unit_diagonal == 0) { DivideFull(xlm[i], xlm[i], alm[i][i]); }
}
}
barrier(CLK_LOCAL_MEM_FENCE);
@@ -99,7 +103,6 @@ void trsv_backward(int n,
if (is_transposed == 0) {
for (int i = 0; i < n; ++i) {
alm[i][tid] = A[i + tid*a_ld + a_offset];
- if (do_conjugate) { COMPLEX_CONJUGATE(alm[i][tid]); }
}
}
else {
@@ -107,6 +110,11 @@ void trsv_backward(int n,
alm[i][tid] = A[tid + i*a_ld + a_offset];
}
}
+ if (do_conjugate) {
+ for (int i = 0; i < n; ++i) {
+ COMPLEX_CONJUGATE(alm[i][tid]);
+ }
+ }
}
barrier(CLK_LOCAL_MEM_FENCE);
@@ -116,7 +124,7 @@ void trsv_backward(int n,
for (int j = i + 1; j < n; ++j) {
MultiplySubtract(xlm[i], alm[i][j], xlm[j]);
}
- if (is_unit_diagonal == 0) { DivideReal(xlm[i], xlm[i], alm[i][i]); }
+ if (is_unit_diagonal == 0) { DivideFull(xlm[i], xlm[i], alm[i][i]); }
}
}
barrier(CLK_LOCAL_MEM_FENCE);
diff --git a/test/correctness/tester.cpp b/test/correctness/tester.cpp
index c449b09d..dc0f842e 100644
--- a/test/correctness/tester.cpp
+++ b/test/correctness/tester.cpp
@@ -410,6 +410,10 @@ bool TestSimilarityNear(const T val1, const T val2,
if (val1 == val2) {
return true;
}
+ // Handles cases with both results NaN
+ else if (std::isnan(val1) && std::isnan(val2)) {
+ return true;
+ }
// The values are zero or very small: the relative error is less meaningful
else if (val1 == 0 || val2 == 0 || difference < error_margin_absolute) {
return (difference < error_margin_absolute);