summaryrefslogtreecommitdiff
path: root/scripts/generator
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-04-06 20:56:28 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2017-04-06 20:56:28 +0200
commit52dd7433caac3f30b6c02ed299ec1b16dc7614ea (patch)
tree3ed2be5bdddf948033dd03c0b1e4b6759ed11d69 /scripts/generator
parentdbe22b5bf3da02a2d94280361cddde1f8f66b63f (diff)
Completed the cuBLAS wrapper
Diffstat (limited to 'scripts/generator')
-rwxr-xr-xscripts/generator/generator.py2
-rw-r--r--scripts/generator/generator/convert.py2
-rw-r--r--scripts/generator/generator/cpp.py16
-rw-r--r--scripts/generator/generator/datatype.py5
-rw-r--r--scripts/generator/generator/routine.py57
5 files changed, 59 insertions, 23 deletions
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 3f3fab62..8810397c 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -42,7 +42,7 @@ FILES = [
"/include/clblast_netlib_c.h",
"/src/clblast_netlib_c.cpp",
]
-HEADER_LINES = [123, 76, 126, 23, 29, 41, 29, 65, 32]
+HEADER_LINES = [122, 77, 126, 23, 29, 41, 29, 65, 32]
FOOTER_LINES = [25, 138, 27, 38, 6, 6, 6, 9, 2]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 63
diff --git a/scripts/generator/generator/convert.py b/scripts/generator/generator/convert.py
index 80b6f338..07f45669 100644
--- a/scripts/generator/generator/convert.py
+++ b/scripts/generator/generator/convert.py
@@ -59,7 +59,7 @@ def option_to_cblas(x):
def option_to_cublas(x):
"""As above, but for clBLAS data-types"""
return {
- 'layout': "cublas_has_no_layout",
+ 'layout': "Layout",
'a_transpose': "cublasOperation_t",
'b_transpose': "cublasOperation_t",
'ab_transpose': "cublasOperation_t",
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py
index 49240095..7c695dc8 100644
--- a/scripts/generator/generator/cpp.py
+++ b/scripts/generator/generator/cpp.py
@@ -304,8 +304,22 @@ def wrapper_cublas(routine):
if flavour.precision_name in ["S", "D", "C", "Z"]:
indent = " " * (24 + routine.length())
arguments = routine.arguments_wrapper_cublas(flavour)
+
+ # Handles row-major
+ if routine.has_layout():
+ result += " if (layout == Layout::kRowMajor) { return CUBLAS_STATUS_NOT_SUPPORTED; }" + NL
+
+ # Complex scalars
+ for scalar in routine.scalars:
+ if flavour.is_complex(scalar):
+ cuda_complex = "cuDoubleComplex" if flavour.precision_name == "Z" else "cuComplex"
+ result += " " + cuda_complex + " " + scalar + "_cuda;" + NL
+ result += " " + scalar + "_cuda.x = " + scalar + ".real();" + NL
+ result += " " + scalar + "_cuda.y = " + scalar + ".imag();" + NL
+
+ # Calls the cuBLAS routine
result += " cublasHandle_t handle;" + NL
- result += " auto status = cublas" + flavour.name + routine.name + "(handle, "
+ result += " auto status = cublas" + flavour.name_cublas() + routine.name + "(handle, "
result += ("," + NL + indent).join([a for a in arguments]) + ");" + NL
result += " cublasDestroy(handle);" + NL
result += " return status;"
diff --git a/scripts/generator/generator/datatype.py b/scripts/generator/generator/datatype.py
index cab2411a..6ac5681a 100644
--- a/scripts/generator/generator/datatype.py
+++ b/scripts/generator/generator/datatype.py
@@ -87,6 +87,11 @@ class DataType:
"""Current type is of a non-standard type"""
return self.buffer_type in [D_HALF, D_FLOAT2, D_DOUBLE2]
+ def name_cublas(self):
+ if "i" in self.name:
+ return "I" + self.name[1].lower()
+ return self.name
+
# Regular data-types
H = DataType("H", "H", D_HALF, [D_HALF] * 2 + [D_HALF_OPENCL] * 2, D_HALF) # half (16)
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index 9414eb50..b1db484f 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -197,6 +197,10 @@ class Routine:
"""Determines whether or not this routine has scalar arguments (alpha/beta)"""
return self.scalars == []
+ def has_layout(self):
+ """Determines whether the layout is an argument"""
+ return "layout" in self.options
+
def short_names(self):
"""Returns the upper-case names of these routines (all flavours)"""
return "/".join([f.name + self.upper_name() for f in self.flavours])
@@ -339,10 +343,16 @@ class Routine:
return [", ".join(a + c)]
return []
- def buffer_wrapper_cublas(self, name):
+ def buffer_wrapper_cublas(self, name, flavour):
"""As above but for cuBLAS the wrapper"""
+ prefix = "const " if name in self.inputs else ""
if name in self.inputs or name in self.outputs:
- a = ["&" + name + "_buffer[" + name + "_offset]"]
+ if flavour.precision_name in ["C", "Z"]:
+ cuda_complex = "cuDoubleComplex" if flavour.precision_name == "Z" else "cuComplex"
+ a = ["reinterpret_cast<" + prefix + cuda_complex + "*>" +
+ "(&" + name + "_buffer[" + name + "_offset])"]
+ else:
+ a = ["&" + name + "_buffer[" + name + "_offset]"]
c = []
if name in ["x", "y"]:
c = ["static_cast<int>(" + name + "_" + self.postfix(name) + ")"]
@@ -421,16 +431,6 @@ class Routine:
return [name]
return []
- def scalar_use_wrapper_by_ref(self, name, flavour):
- """As above, but for the cuBLAS wrapper"""
- if name in self.scalars:
- if name == "alpha":
- return ["&" + flavour.use_alpha_opencl()]
- elif name == "beta":
- return ["&" + flavour.use_beta_opencl()]
- return [name]
- return []
-
def scalar_use_wrapper_cblas(self, name, flavour):
"""As above, but for the CBLAS wrapper"""
if name in self.scalars:
@@ -439,6 +439,14 @@ class Routine:
return [name]
return []
+ def scalar_use_wrapper_cublas(self, name, flavour):
+ """As above, but for the cuBLAS wrapper"""
+ if name in self.scalars:
+ if flavour.is_complex(name):
+ return ["&" + name + "_cuda"]
+ return ["&" + name]
+ return []
+
def scalar_def(self, name, flavour):
"""Retrieves the definition of a scalar (alpha/beta)"""
if name in self.scalars:
@@ -534,6 +542,15 @@ class Routine:
return [", ".join(self.options)]
return []
+ def options_list_no_layout(self):
+ """Retrieves a list of options"""
+ options = self.options[:]
+ if "layout" in options:
+ options.remove("layout")
+ if options:
+ return [", ".join(options)]
+ return []
+
def options_cast(self, indent):
"""As above, but now casted to CLBlast data-types"""
if self.options:
@@ -670,14 +687,14 @@ class Routine:
def arguments_wrapper_cublas(self, flavour):
"""As above, but for the cuBLAS wrapper"""
- return (self.options_list() + self.sizes_list_as_int() +
- list(chain(*[self.buffer_wrapper_cublas(b) for b in self.scalar_buffers_first()])) +
- self.scalar_use_wrapper_by_ref("alpha", flavour) +
- list(chain(*[self.buffer_wrapper_cublas(b) for b in self.buffers_first()])) +
- self.scalar_use_wrapper_by_ref("beta", flavour) +
- list(chain(*[self.buffer_wrapper_cublas(b) for b in self.buffers_second()])) +
- list(chain(*[self.buffer_wrapper_cublas(b) for b in self.scalar_buffers_second()])) +
- list(chain(*[self.scalar_use_wrapper_by_ref(s, flavour) for s in self.other_scalars()])))
+ return (self.options_list_no_layout() + self.sizes_list_as_int() +
+ self.scalar_use_wrapper_cublas("alpha", flavour) +
+ list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.buffers_first()])) +
+ self.scalar_use_wrapper_cublas("beta", flavour) +
+ list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.scalar_buffers_first()])) +
+ list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar_use_wrapper_cublas(s, flavour) for s in self.other_scalars()])))
def arguments_def(self, flavour):
"""Retrieves a combination of all the argument definitions"""