summary | refs | log | tree | commit | diff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rw-r--r-- scripts/generator/generator/cpp.py 1
-rw-r--r-- scripts/generator/generator/routine.py 12
2 files changed, 11 insertions, 2 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py
index 7c695dc8..79d6b2a1 100644
--- a/scripts/generator/generator/cpp.py
+++ b/scripts/generator/generator/cpp.py
@@ -319,6 +319,7 @@ def wrapper_cublas(routine):
# Calls the cuBLAS routine
result += " cublasHandle_t handle;" + NL
+ result += " if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) { return CUBLAS_STATUS_NOT_INITIALIZED; }" + NL
result += " auto status = cublas" + flavour.name_cublas() + routine.name + "(handle, "
result += ("," + NL + indent).join([a for a in arguments]) + ");" + NL
result += " cublasDestroy(handle);" + NL
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index b1db484f..a7abfde5 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -347,7 +347,12 @@ class Routine:
"""As above but for cuBLAS the wrapper"""
prefix = "const " if name in self.inputs else ""
if name in self.inputs or name in self.outputs:
- if flavour.precision_name in ["C", "Z"]:
+ if name in self.index_buffers():
+ a = ["reinterpret_cast<int*>(&" + name + "_buffer[" + name + "_offset])"]
+ elif name in self.outputs and flavour.name in ["Sc", "Dz"]:
+ dtype = "float" if flavour.name == "Sc" else "double"
+ a = ["reinterpret_cast<" + dtype + "*>(&" + name + "_buffer[" + name + "_offset])"]
+ elif flavour.precision_name in ["C", "Z"]:
cuda_complex = "cuDoubleComplex" if flavour.precision_name == "Z" else "cuComplex"
a = ["reinterpret_cast<" + prefix + cuda_complex + "*>" +
"(&" + name + "_buffer[" + name + "_offset])"]
@@ -358,7 +363,10 @@ class Routine:
c = ["static_cast<int>(" + name + "_" + self.postfix(name) + ")"]
elif name in ["a", "b", "c"]:
c = [name + "_" + self.postfix(name)]
- return [", ".join(a + c)]
+ result = [", ".join(a + c)]
+ if self.name == "trmm" and name == "a":
+ result *= 2
+ return result
return []
def buffer_type(self, name):