From 52dd7433caac3f30b6c02ed299ec1b16dc7614ea Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Thu, 6 Apr 2017 20:56:28 +0200 Subject: Completed the cuBLAS wrapper --- scripts/generator/generator.py | 2 +- scripts/generator/generator/convert.py | 2 +- scripts/generator/generator/cpp.py | 16 ++++++++- scripts/generator/generator/datatype.py | 5 +++ scripts/generator/generator/routine.py | 57 +++++++++++++++++++++------------ 5 files changed, 59 insertions(+), 23 deletions(-) (limited to 'scripts') diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index 3f3fab62..8810397c 100755 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -42,7 +42,7 @@ FILES = [ "/include/clblast_netlib_c.h", "/src/clblast_netlib_c.cpp", ] -HEADER_LINES = [123, 76, 126, 23, 29, 41, 29, 65, 32] +HEADER_LINES = [122, 77, 126, 23, 29, 41, 29, 65, 32] FOOTER_LINES = [25, 138, 27, 38, 6, 6, 6, 9, 2] HEADER_LINES_DOC = 0 FOOTER_LINES_DOC = 63 diff --git a/scripts/generator/generator/convert.py b/scripts/generator/generator/convert.py index 80b6f338..07f45669 100644 --- a/scripts/generator/generator/convert.py +++ b/scripts/generator/generator/convert.py @@ -59,7 +59,7 @@ def option_to_cblas(x): def option_to_cublas(x): """As above, but for clBLAS data-types""" return { - 'layout': "cublas_has_no_layout", + 'layout': "Layout", 'a_transpose': "cublasOperation_t", 'b_transpose': "cublasOperation_t", 'ab_transpose': "cublasOperation_t", diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py index 49240095..7c695dc8 100644 --- a/scripts/generator/generator/cpp.py +++ b/scripts/generator/generator/cpp.py @@ -304,8 +304,22 @@ def wrapper_cublas(routine): if flavour.precision_name in ["S", "D", "C", "Z"]: indent = " " * (24 + routine.length()) arguments = routine.arguments_wrapper_cublas(flavour) + + # Handles row-major + if routine.has_layout(): + result += " if (layout == Layout::kRowMajor) { return CUBLAS_STATUS_NOT_SUPPORTED; }" + NL + + # Complex scalars + for scalar in routine.scalars: + if flavour.is_complex(scalar): + cuda_complex = "cuDoubleComplex" if flavour.precision_name == "Z" else "cuComplex" + result += " " + cuda_complex + " " + scalar + "_cuda;" + NL + result += " " + scalar + "_cuda.x = " + scalar + ".real();" + NL + result += " " + scalar + "_cuda.y = " + scalar + ".imag();" + NL + + # Calls the cuBLAS routine result += " cublasHandle_t handle;" + NL - result += " auto status = cublas" + flavour.name + routine.name + "(handle, " + result += " auto status = cublas" + flavour.name_cublas() + routine.name + "(handle, " result += ("," + NL + indent).join([a for a in arguments]) + ");" + NL result += " cublasDestroy(handle);" + NL result += " return status;" diff --git a/scripts/generator/generator/datatype.py b/scripts/generator/generator/datatype.py index cab2411a..6ac5681a 100644 --- a/scripts/generator/generator/datatype.py +++ b/scripts/generator/generator/datatype.py @@ -87,6 +87,11 @@ class DataType: """Current type is of a non-standard type""" return self.buffer_type in [D_HALF, D_FLOAT2, D_DOUBLE2] + def name_cublas(self): + if "i" in self.name: + return "I" + self.name[1].lower() + return self.name + # Regular data-types H = DataType("H", "H", D_HALF, [D_HALF] * 2 + [D_HALF_OPENCL] * 2, D_HALF) # half (16) diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py index 9414eb50..b1db484f 100644 --- a/scripts/generator/generator/routine.py +++ b/scripts/generator/generator/routine.py @@ -197,6 +197,10 @@ class Routine: """Determines whether or not this routine has scalar arguments (alpha/beta)""" return self.scalars == [] + def has_layout(self): + """Determines whether the layout is an argument""" + return "layout" in self.options + def short_names(self): """Returns the upper-case names of these routines (all flavours)""" return "/".join([f.name + self.upper_name() for f in self.flavours]) @@ -339,10 +343,16 @@ class Routine: return [", ".join(a + c)] return [] - def buffer_wrapper_cublas(self, name): + def buffer_wrapper_cublas(self, name, flavour): """As above but for cuBLAS the wrapper""" + prefix = "const " if name in self.inputs else "" if name in self.inputs or name in self.outputs: - a = ["&" + name + "_buffer[" + name + "_offset]"] + if flavour.precision_name in ["C", "Z"]: + cuda_complex = "cuDoubleComplex" if flavour.precision_name == "Z" else "cuComplex" + a = ["reinterpret_cast<" + prefix + cuda_complex + "*>" + + "(&" + name + "_buffer[" + name + "_offset])"] + else: + a = ["&" + name + "_buffer[" + name + "_offset]"] c = [] if name in ["x", "y"]: c = ["static_cast(" + name + "_" + self.postfix(name) + ")"] @@ -421,16 +431,6 @@ class Routine: return [name] return [] - def scalar_use_wrapper_by_ref(self, name, flavour): - """As above, but for the cuBLAS wrapper""" - if name in self.scalars: - if name == "alpha": - return ["&" + flavour.use_alpha_opencl()] - elif name == "beta": - return ["&" + flavour.use_beta_opencl()] - return [name] - return [] - def scalar_use_wrapper_cblas(self, name, flavour): """As above, but for the CBLAS wrapper""" if name in self.scalars: @@ -439,6 +439,14 @@ class Routine: return [name] return [] + def scalar_use_wrapper_cublas(self, name, flavour): + """As above, but for the cuBLAS wrapper""" + if name in self.scalars: + if flavour.is_complex(name): + return ["&" + name + "_cuda"] + return ["&" + name] + return [] + def scalar_def(self, name, flavour): """Retrieves the definition of a scalar (alpha/beta)""" if name in self.scalars: @@ -534,6 +542,15 @@ class Routine: return [", ".join(self.options)] return [] + def options_list_no_layout(self): + """Retrieves a list of options""" + options = self.options[:] + if "layout" in options: + options.remove("layout") + if options: + return [", ".join(options)] + return [] + def options_cast(self, indent): """As above, but now casted to CLBlast data-types""" if self.options: @@ -670,14 +687,14 @@ class Routine: def arguments_wrapper_cublas(self, flavour): """As above, but for the cuBLAS wrapper""" - return (self.options_list() + self.sizes_list_as_int() + - list(chain(*[self.buffer_wrapper_cublas(b) for b in self.scalar_buffers_first()])) + - self.scalar_use_wrapper_by_ref("alpha", flavour) + - list(chain(*[self.buffer_wrapper_cublas(b) for b in self.buffers_first()])) + - self.scalar_use_wrapper_by_ref("beta", flavour) + - list(chain(*[self.buffer_wrapper_cublas(b) for b in self.buffers_second()])) + - list(chain(*[self.buffer_wrapper_cublas(b) for b in self.scalar_buffers_second()])) + - list(chain(*[self.scalar_use_wrapper_by_ref(s, flavour) for s in self.other_scalars()]))) + return (self.options_list_no_layout() + self.sizes_list_as_int() + + self.scalar_use_wrapper_cublas("alpha", flavour) + + list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.buffers_first()])) + + self.scalar_use_wrapper_cublas("beta", flavour) + + list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.buffers_second()])) + + list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.scalar_buffers_first()])) + + list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.scalar_buffers_second()])) + + list(chain(*[self.scalar_use_wrapper_cublas(s, flavour) for s in self.other_scalars()]))) def arguments_def(self, flavour): """Retrieves a combination of all the argument definitions""" -- cgit v1.2.3