summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
author Cedric Nugteren <web@cedricnugteren.nl> 2017-03-04 15:21:33 +0100
committer Cedric Nugteren <web@cedricnugteren.nl> 2017-03-04 15:21:33 +0100
commit e993ee077b50d3a6134309d465a4174b5c749596 (patch)
tree b967f2702b90d8080a3e3cb41b9cbc01ab9eddc3 /src
parent 3fc73851f7ed885335940eb85e53069638567323 (diff)
Added a proper data-preparation function for the TRSM tests
Diffstat (limited to 'src')
-rw-r--r-- src/kernels/level3/invert_diagonal_blocks.opencl 4
-rw-r--r-- src/utilities/utilities.cpp 36
-rw-r--r-- src/utilities/utilities.hpp 13
3 files changed, 34 insertions, 19 deletions
diff --git a/src/kernels/level3/invert_diagonal_blocks.opencl b/src/kernels/level3/invert_diagonal_blocks.opencl
index c59bcbcb..55f4a963 100644
--- a/src/kernels/level3/invert_diagonal_blocks.opencl
+++ b/src/kernels/level3/invert_diagonal_blocks.opencl
@@ -140,7 +140,9 @@ void InvertDiagonalBlock(int n, __global const real* restrict src, const int src
for (int k = j + 1; k < INTERNAL_BLOCK_SIZE; ++k) {
MultiplyAdd(sum, lm[thread_index][k], lm[k][j]);
}
- Multiply(lm[thread_index][j], -lm[j][j], sum);
+ real diagonal_value = lm[j][j];
+ Negate(diagonal_value);
+ Multiply(lm[thread_index][j], diagonal_value, sum);
}
barrier(CLK_LOCAL_MEM_FENCE);
}
diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp
index 9cf75490..d68cc1a6 100644
--- a/src/utilities/utilities.cpp
+++ b/src/utilities/utilities.cpp
@@ -55,29 +55,35 @@ template <> half ConstantNegOne() { return FloatToHalf(-1.0f); }
template <> float2 ConstantNegOne() { return {-1.0f, 0.0f}; }
template <> double2 ConstantNegOne() { return {-1.0, 0.0}; }
-// Returns a scalar of value 1
-template <typename T> T ConstantTwo() { return static_cast<T>(2.0); }
-template float ConstantTwo<float>();
-template double ConstantTwo<double>();
-template <> half ConstantTwo() { return FloatToHalf(2.0f); }
-template <> float2 ConstantTwo() { return {2.0f, 0.0f}; }
-template <> double2 ConstantTwo() { return {2.0, 0.0}; }
+// Returns a scalar of some value
+template <typename T> T Constant(const double val) { return static_cast<T>(val); }
+template float Constant<float>(const double);
+template double Constant<double>(const double);
+template <> half Constant(const double val) { return FloatToHalf(static_cast<float>(val)); }
+template <> float2 Constant(const double val) { return {static_cast<float>(val), 0.0f}; }
+template <> double2 Constant(const double val) { return {val, 0.0}; }
// Returns a small scalar value just larger than 0
-template <typename T> T SmallConstant() { return static_cast<T>(1e7); }
+template <typename T> T SmallConstant() { return static_cast<T>(1e-4); }
template float SmallConstant<float>();
template double SmallConstant<double>();
-template <> half SmallConstant() { return FloatToHalf(1e7); }
-template <> float2 SmallConstant() { return {1e7, 0.0f}; }
-template <> double2 SmallConstant() { return {1e7, 0.0}; }
+template <> half SmallConstant() { return FloatToHalf(1e-4); }
+template <> float2 SmallConstant() { return {1e-4, 0.0f}; }
+template <> double2 SmallConstant() { return {1e-4, 0.0}; }
-// Returns the absolute value of a scalar
-template <typename T> T AbsoluteValue(const T value) { return std::fabs(value); }
+// Returns the absolute value of a scalar (modulus in case of a complex number)
+template <typename T> typename BaseType<T>::Type AbsoluteValue(const T value) { return std::fabs(value); }
template float AbsoluteValue<float>(const float);
template double AbsoluteValue<double>(const double);
template <> half AbsoluteValue(const half value) { return FloatToHalf(std::fabs(HalfToFloat(value))); }
-template <> float2 AbsoluteValue(const float2 value) { return std::abs(value); }
-template <> double2 AbsoluteValue(const double2 value) { return std::abs(value); }
+template <> float AbsoluteValue(const float2 value) {
+ if (value.real() == 0.0f && value.imag() == 0.0f) { return 0.0f; }
+ return std::sqrt(value.real() * value.real() + value.imag() * value.imag());
+}
+template <> double AbsoluteValue(const double2 value) {
+ if (value.real() == 0.0 && value.imag() == 0.0) { return 0.0; }
+ return std::sqrt(value.real() * value.real() + value.imag() * value.imag());
+}
// Returns whether a scalar is close to zero
template <typename T> bool IsCloseToZero(const T value) { return (value > -SmallConstant<T>()) && (value < SmallConstant<T>()); }
diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp
index 044955ea..3c9be6a2 100644
--- a/src/utilities/utilities.hpp
+++ b/src/utilities/utilities.hpp
@@ -98,6 +98,13 @@ constexpr auto kArgNoAbbreviations = "no_abbrv";
// =================================================================================================
+// Converts a regular or complex type to its base type (e.g. float2 to float)
+template <typename T> struct BaseType { using Type = T; };
+template <> struct BaseType<float2> { using Type = float; };
+template <> struct BaseType<double2> { using Type = double; };
+
+// =================================================================================================
+
// Returns a scalar with a default value
template <typename T> T GetScalar();
@@ -105,11 +112,11 @@ template <typename T> T GetScalar();
template <typename T> T ConstantZero();
template <typename T> T ConstantOne();
template <typename T> T ConstantNegOne();
-template <typename T> T ConstantTwo();
+template <typename T> T Constant(const double val);
template <typename T> T SmallConstant();
-// Returns the absolute value of a scalar
-template <typename T> T AbsoluteValue(const T value);
+// Returns the absolute value of a scalar (modulus in case of complex numbers)
+template <typename T> typename BaseType<T>::Type AbsoluteValue(const T value);
// Returns whether a scalar is close to zero
template <typename T> bool IsCloseToZero(const T value);