From f2ba75890c522b4fe1762bfeac3e08667cf9588a Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Thu, 12 May 2016 19:56:21 +0200 Subject: Initial changes in preparation for half-precision fp16 support --- include/internal/tuning.h | 2 + include/internal/utilities.h | 3 ++ scripts/generator/datatype.py | 3 ++ scripts/generator/generator.py | 114 +++++++++++++++++++++-------------------- scripts/generator/routine.py | 6 +++ src/kernels/common.opencl | 17 ++++-- src/routine.cc | 1 + src/utilities.cc | 20 ++++++++ 8 files changed, 108 insertions(+), 58 deletions(-) diff --git a/include/internal/tuning.h b/include/internal/tuning.h index 5645a5e5..6ba1db61 100644 --- a/include/internal/tuning.h +++ b/include/internal/tuning.h @@ -20,6 +20,8 @@ #include +#include "internal/utilities.h" + namespace clblast { // ================================================================================================= diff --git a/include/internal/utilities.h b/include/internal/utilities.h index 82cd7f44..46d9b8f1 100644 --- a/include/internal/utilities.h +++ b/include/internal/utilities.h @@ -27,6 +27,9 @@ namespace clblast { // ================================================================================================= +// Host data-type for half-precision floating-point (16-bit) +using half = cl_half; + // Shorthands for complex data-types using float2 = std::complex; using double2 = std::complex; diff --git a/scripts/generator/datatype.py b/scripts/generator/datatype.py index 5a58ab53..5bff95d1 100644 --- a/scripts/generator/datatype.py +++ b/scripts/generator/datatype.py @@ -13,10 +13,13 @@ # ================================================================================================== # Short-hands for data-types +HLF = "half" FLT = "float" DBL = "double" FLT2 = "float2" DBL2 = "double2" + +HCL = "cl_half" F2CL = "cl_float2" D2CL = "cl_double2" diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index 210f371f..bc8fa783 100644 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -28,11 +28,12 @@ import os.path # Local files from routine import Routine -from datatype import DataType, FLT, DBL, FLT2, DBL2, F2CL, D2CL +from datatype import DataType, HLF, FLT, DBL, FLT2, DBL2, HCL, F2CL, D2CL # ================================================================================================== # Regular data-types +H = DataType("H", "H", HLF, [HLF, HLF, HCL, HCL], HLF ) # half (16) S = DataType("S", "S", FLT, [FLT, FLT, FLT, FLT], FLT ) # single (32) D = DataType("D", "D", DBL, [DBL, DBL, DBL, DBL], DBL ) # double (64) C = DataType("C", "C", FLT2, [FLT2, FLT2, F2CL, F2CL], FLT2) # single-complex (3232) @@ -67,7 +68,7 @@ routines = [ Routine(True, True, "1", "swap", T, [S,D,C,Z], ["n"], [], [], ["x","y"], [], "", "Swap two vectors", "Interchanges the contents of vectors x and y.", []), Routine(True, True, "1", "scal", T, [S,D,C,Z], ["n"], [], [], ["x"], ["alpha"], "", "Vector scaling", "Multiplies all elements of vector x by a scalar constant alpha.", []), Routine(True, True, "1", "copy", T, [S,D,C,Z], ["n"], [], ["x"], ["y"], [], "", "Vector copy", "Copies the contents of vector x into vector y.", []), - Routine(True, True, "1", "axpy", T, [S,D,C,Z], ["n"], [], ["x"], ["y"], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation y = alpha * x + y, in which x and y are vectors and alpha is a scalar constant.", []), + Routine(True, True, "1", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation y = alpha * x + y, in which x and y are vectors and alpha is a scalar constant.", []), Routine(True, True, "1", "dot", T, [S,D], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two vectors", "Multiplies the vectors x and y element-wise and accumulates the results. The sum is stored in the dot buffer.", []), Routine(True, True, "1", "dotu", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two complex vectors", "See the regular xDOT routine.", []), Routine(True, True, "1", "dotc", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two complex vectors, one conjugated", "See the regular xDOT routine.", []), @@ -229,22 +230,23 @@ def wrapper_clblas(routines): result = "" for routine in routines: if routine.has_tests: - result += "\n// Forwards the clBLAS calls for %s\n" % (routine.ShortNames()) + result += "\n// Forwards the clBLAS calls for %s\n" % (routine.ShortNamesTested()) if routine.NoScalars(): result += routine.RoutineHeaderWrapperCL(routine.template, True, 21)+";\n" for flavour in routine.flavours: - indent = " "*(17 + routine.Length()) - result += routine.RoutineHeaderWrapperCL(flavour, False, 21)+" {\n" - arguments = routine.ArgumentsWrapperCL(flavour) - if routine.scratch: - result += " auto queue = Queue(queues[0]);\n" - result += " auto context = queue.GetContext();\n" - result += " auto scratch_buffer = Buffer<"+flavour.template+">(context, "+routine.scratch+");\n" - arguments += ["scratch_buffer()"] - result += " return clblas"+flavour.name+routine.name+"(" - result += (",\n"+indent).join([a for a in arguments]) - result += ",\n"+indent+"num_queues, queues, num_wait_events, wait_events, events);" - result += "\n}\n" + if flavour.precision_name in ["S","D","C","Z"]: + indent = " "*(17 + routine.Length()) + result += routine.RoutineHeaderWrapperCL(flavour, False, 21)+" {\n" + arguments = routine.ArgumentsWrapperCL(flavour) + if routine.scratch: + result += " auto queue = Queue(queues[0]);\n" + result += " auto context = queue.GetContext();\n" + result += " auto scratch_buffer = Buffer<"+flavour.template+">(context, "+routine.scratch+");\n" + arguments += ["scratch_buffer()"] + result += " return clblas"+flavour.name+routine.name+"(" + result += (",\n"+indent).join([a for a in arguments]) + result += ",\n"+indent+"num_queues, queues, num_wait_events, wait_events, events);" + result += "\n}\n" return result # The wrapper to the reference CBLAS routines (for performance/correctness testing) @@ -252,44 +254,45 @@ def wrapper_cblas(routines): result = "" for routine in routines: if routine.has_tests: - result += "\n// Forwards the Netlib BLAS calls for %s\n" % (routine.ShortNames()) + result += "\n// Forwards the Netlib BLAS calls for %s\n" % (routine.ShortNamesTested()) for flavour in routine.flavours: - indent = " "*(10 + routine.Length()) - result += routine.RoutineHeaderWrapperC(flavour, False, 12)+" {\n" - arguments = routine.ArgumentsWrapperC(flavour) - - # Double-precision scalars - for scalar in routine.scalars: - if flavour.IsComplex(scalar): - result += " const auto "+scalar+"_array = std::vector<"+flavour.buffertype[:-1]+">{"+scalar+".real(), "+scalar+".imag()};\n" - - # Special case for scalar outputs - assignment = "" - postfix = "" - endofline = "" - extra_argument = "" - for output_buffer in routine.outputs: - if output_buffer in routine.ScalarBuffersFirst(): - if flavour in [C,Z]: - postfix += "_sub" - indent += " " - extra_argument += ",\n"+indent+"reinterpret_cast(&"+output_buffer+"_buffer["+output_buffer+"_offset])" - elif output_buffer in routine.IndexBuffers(): - assignment = "((int*)&"+output_buffer+"_buffer[0])["+output_buffer+"_offset] = " - indent += " "*len(assignment) - else: - assignment = output_buffer+"_buffer["+output_buffer+"_offset]" - if (flavour.name in ["Sc","Dz"]): - assignment = assignment+".real(" - endofline += ")" + if flavour.precision_name in ["S","D","C","Z"]: + indent = " "*(10 + routine.Length()) + result += routine.RoutineHeaderWrapperC(flavour, False, 12)+" {\n" + arguments = routine.ArgumentsWrapperC(flavour) + + # Double-precision scalars + for scalar in routine.scalars: + if flavour.IsComplex(scalar): + result += " const auto "+scalar+"_array = std::vector<"+flavour.buffertype[:-1]+">{"+scalar+".real(), "+scalar+".imag()};\n" + + # Special case for scalar outputs + assignment = "" + postfix = "" + endofline = "" + extra_argument = "" + for output_buffer in routine.outputs: + if output_buffer in routine.ScalarBuffersFirst(): + if flavour in [C,Z]: + postfix += "_sub" + indent += " " + extra_argument += ",\n"+indent+"reinterpret_cast(&"+output_buffer+"_buffer["+output_buffer+"_offset])" + elif output_buffer in routine.IndexBuffers(): + assignment = "((int*)&"+output_buffer+"_buffer[0])["+output_buffer+"_offset] = " + indent += " "*len(assignment) else: - assignment = assignment+" = " - indent += " "*len(assignment) - - result += " "+assignment+"cblas_"+flavour.name.lower()+routine.name+postfix+"(" - result += (",\n"+indent).join([a for a in arguments]) - result += extra_argument+endofline+");" - result += "\n}\n" + assignment = output_buffer+"_buffer["+output_buffer+"_offset]" + if (flavour.name in ["Sc","Dz"]): + assignment = assignment+".real(" + endofline += ")" + else: + assignment = assignment+" = " + indent += " "*len(assignment) + + result += " "+assignment+"cblas_"+flavour.name.lower()+routine.name+postfix+"(" + result += (",\n"+indent).join([a for a in arguments]) + result += extra_argument+endofline+");" + result += "\n}\n" return result # ================================================================================================== @@ -368,9 +371,10 @@ for level in [1,2,3]: body += "int main(int argc, char *argv[]) {\n" not_first = "false" for flavour in routine.flavours: - body += " clblast::RunTests::PadCopyTransposeMatrix(EventPointer event, std::vector; template class Routine; template class Routine; template class Routine; diff --git a/src/utilities.cc b/src/utilities.cc index 68a4f02a..5325c3e8 100644 --- a/src/utilities.cc +++ b/src/utilities.cc @@ -29,6 +29,7 @@ std::string ToString(T value) { } template std::string ToString(int value); template std::string ToString(size_t value); +template std::string ToString(half value); template std::string ToString(float value); template std::string ToString(double value); @@ -105,6 +106,9 @@ template T ConvertArgument(const char* value) { return static_cast(std::stoi(value)); } +template <> half ConvertArgument(const char* value) { + return static_cast(std::stod(value)); +} template <> float ConvertArgument(const char* value) { return static_cast(std::stod(value)); } @@ -147,6 +151,7 @@ T GetArgument(const int argc, char *argv[], std::string &help, // Compiles the above function template int GetArgument(const int, char **, std::string&, const std::string&, const int); template size_t GetArgument(const int, char **, std::string&, const std::string&, const size_t); +template half GetArgument(const int, char **, std::string&, const std::string&, const half); template float GetArgument(const int, char **, std::string&, const std::string&, const float); template double GetArgument(const int, char **, std::string&, const std::string&, const double); template float2 GetArgument(const int, char **, std::string&, const std::string&, const float2); @@ -227,6 +232,16 @@ void PopulateVector(std::vector &vector) { for (auto &element: vector) { element.real(dist(mt)); element.imag(dist(mt)); } } +// Specialized versions of the above for half-precision +template <> +void PopulateVector(std::vector &vector) { + auto lower_limit = static_cast(kTestDataLowerLimit); + auto upper_limit = static_cast(kTestDataUpperLimit); + std::mt19937 mt(GetRandomSeed()); + std::uniform_real_distribution dist(lower_limit, upper_limit); + for (auto &element: vector) { element = static_cast(dist(mt)); } +} + // ================================================================================================= // Returns a scalar with a default value @@ -234,6 +249,7 @@ template T GetScalar() { return static_cast(2.0); } +template half GetScalar(); template float GetScalar(); template double GetScalar(); @@ -288,6 +304,10 @@ template <> bool PrecisionSupported(const Device &device) { auto extensions = device.Capabilities(); return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true; } +template <> bool PrecisionSupported(const Device &device) { + auto extensions = device.Capabilities(); + return (extensions.find(kKhronosHalfPrecision) == std::string::npos) ? false : true; +} // ================================================================================================= } // namespace clblast -- cgit v1.2.3